训练模型需要的数据文件有:

MNIST_data文件夹下的mnist_train、mnist_test、noisy_train、noisy_test。train文件夹下60000个图片,test下10000个图片

noisy_train、noisy_test下的图片加了椒盐噪声与原图序号对应

离线测试需要的数据文件有:

MNIST_data文件夹下的my_model.hdf5、my_test。my_test文件夹下要有一层嵌套文件夹并放测试图片

数据集准备参考:

https://www.cnblogs.com/dzzy/p/10824072.html

训练:

import os
import glob
from PIL import Image
import numpy as np
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator,img_to_array, load_img
from keras.callbacks import TensorBoard , ModelCheckpoint
print("_________________________keras start_____________________________") base_dir = 'MNIST_data' #基准目录
#os._exit(0)
Datapath = os.path.join(base_dir,'mnist_train/*.png') #train目录
x_train = np.zeros((60000, 28, 28))
x_train = np.reshape(x_train, (60000, 28, 28, 1))
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_train[i] = img
i += 1 Datapath = os.path.join(base_dir,'mnist_test/*.png') #test目录
x_test = np.zeros((10000, 28, 28))
x_test = np.reshape(x_test, (10000, 28, 28, 1))
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_test[i] = img
i += 1 print( x_train.shape)
print( x_test.shape) Datapath = os.path.join(base_dir,'noisy_train/*.png') #test目录
x_train_noisy = np.zeros(x_train.shape)
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_train_noisy[i] = img
i += 1 Datapath = os.path.join(base_dir,'noisy_test/*.png') #test目录
x_test_noisy = np.zeros(x_test.shape)
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_test_noisy[i] = img
i += 1 print( x_train_noisy.shape)
print( x_test_noisy.shape) '''
plt.figure(figsize=(20, 4))
plt.subplot(4, 4, 1)
plt.imshow(x_train[0].reshape(28, 28))
plt.subplot(4, 4, 2)
plt.imshow(x_train_noisy[0].reshape(28, 28))
plt.subplot(4, 4, 3)
plt.imshow(x_train[1].reshape(28, 28))
plt.subplot(4, 4, 4)
plt.imshow(x_train_noisy[1].reshape(28, 28))
plt.show()
#os._exit(0)
''' """
搭建模型
"""
input_img = Input(shape=(28, 28, 1)) # adapt this if using `channels_first` image data format x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img) #relu激活函数
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x) # at this point the representation is (7, 7, 32) x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x) autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy') file_path="MNIST_data/weights-improvement-{epoch:02d}-{val_loss:.2f}.hdf5"
tensorboard = TensorBoard(log_dir='/tmp/tb', histogram_freq=0, write_graph=False)
checkpoint = ModelCheckpoint(filepath=file_path,verbose=1,monitor='val_loss', save_weights_only=False,mode='auto' ,save_best_only=True,period=1)
autoencoder.fit(x_train_noisy, x_train,
epochs=100,
batch_size=128,
shuffle=True,
validation_data=(x_test_noisy, x_test),
callbacks=[checkpoint,tensorboard]) #展示结果
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
#noisy data
ax = plt.subplot(3, n, i+1)
plt.imshow(x_test_noisy[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
#predict
ax = plt.subplot(3, n, i+1+n)
decoded_img = autoencoder.predict(x_test_noisy)
plt.imshow(decoded_img[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
#original
ax = plt.subplot(3, n, i+1+2*n)
plt.imshow(x_test[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
plt.show()

测试:(同  https://www.cnblogs.com/dzzy/p/11387645.html

import os
import numpy as np
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.models import Model,Sequential,load_model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator,img_to_array, load_img
from keras.callbacks import TensorBoard , ModelCheckpoint
print("_________________________keras start_____________________________")
pic_num = 3
base_dir = 'MNIST_data' #基准目录
train_dir = os.path.join(base_dir,'my_test') #train目录
validation_dir="".join(train_dir)
test_datagen = ImageDataGenerator(rescale= 1./255)
validation_generator = test_datagen.flow_from_directory(validation_dir,
target_size = (28,28),
color_mode = "grayscale",
batch_size = pic_num,
class_mode = "categorical")#利用test_datagen.flow_from_directory(图像地址,目标size,批量数目,标签分类情况)
for x_train,batch_labels in validation_generator:
break
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
y_train = x_train # create model
model = load_model('MNIST_data/my_model.hdf5')
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print("Created model and loaded weights from file") # estimate accuracy on whole dataset using loaded weights
y_train=model.predict(x_train)
# 评价训练出的网络
#loss, accuracy = model.evaluate(x_train, y_train)
#print('test loss: ', loss)
#print('test accuracy: ', accuracy) n = pic_num
for i in range(n):
ax = plt.subplot(2, n, i+1)
plt.imshow(x_train[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2, n, i+1+n)
plt.imshow(y_train[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

Keras学习笔记三:一个图像去噪训练并离线测试的例子,基于mnist的更多相关文章

  1. ZooKeeper学习笔记三:使用ZooKeeper实现一个简单的配置中心

    作者:Grey 原文地址:ZooKeeper学习笔记三:使用ZooKeeper实现一个简单的配置中心 前置知识 完成ZooKeeper集群搭建以及熟悉ZooKeeperAPI基本使用 需求 很多程序往 ...

  2. 官网实例详解-目录和实例简介-keras学习笔记四

    官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras   版权声明: ...

  3. Oracle学习笔记三 SQL命令

    SQL简介 SQL 支持下列类别的命令: 1.数据定义语言(DDL) 2.数据操纵语言(DML) 3.事务控制语言(TCL) 4.数据控制语言(DCL)  

  4. ASP.NET MVC Web API 学习笔记---第一个Web API程序

    http://www.cnblogs.com/qingyuan/archive/2012/10/12/2720824.html GetListAll /api/Contact GetListBySex ...

  5. [读书笔记]C#学习笔记三: C#类型详解..

    前言 这次分享的主要内容有五个, 分别是值类型和引用类型, 装箱与拆箱,常量与变量,运算符重载,static字段和static构造函数. 后期的分享会针对于C#2.0 3.0 4.0 等新特性进行. ...

  6. [Firefly引擎][学习笔记三][已完结]所需模块封装

    原地址:http://www.9miao.com/question-15-54671.html 学习笔记一传送门学习笔记二传送门 学习笔记三导读:        笔记三主要就是各个模块的封装了,这里贴 ...

  7. JSP学习笔记(三):简单的Tomcat Web服务器

    注意:每次对Tomcat配置文件进行修改后,必须重启Tomcat 在E盘的DATA文件夹中创建TomcatDemo文件夹,并将Tomcat安装路径下的webapps/ROOT中的WEB-INF文件夹复 ...

  8. java之jvm学习笔记三(Class文件检验器)

    java之jvm学习笔记三(Class文件检验器) 前面的学习我们知道了class文件被类装载器所装载,但是在装载class文件之前或之后,class文件实际上还需要被校验,这就是今天的学习主题,cl ...

  9. VSTO学习笔记(三) 开发Office 2010 64位COM加载项

    原文:VSTO学习笔记(三) 开发Office 2010 64位COM加载项 一.加载项简介 Office提供了多种用于扩展Office应用程序功能的模式,常见的有: 1.Office 自动化程序(A ...

随机推荐

  1. JS基础_嵌套的for循环

    <!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <title&g ...

  2. JavaScript例子3-对多选框进行操作,输出选中的多选框的个数

    <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...

  3. python之字符串类型及其操作

    1.1字符串类型的表示 字符串是字符的序列表示,可以由一对单引号('). 双引号(")或三引号(’")构成.其中,单引号和双引号都可以表示单行字符串,两者作用相同.使用单引号时,双 ...

  4. Eclipse 设置新建文件默认编码为 utf-8 的方法

    选择编辑器顶部 Windows->Preferences->搜索jsp->选择utf-8编码->保存.

  5. 04 Go语言之包

    1.为什么有包这个概念? 1)开发中,往往要在不同的文件中调用其他文件的函数 2)Go代码最小粒度单位是”包” 3)go的每一个文件都属于一个包,通过package管理 4)go以包的形式管理文件和项 ...

  6. java 将一个正整数翻译成人民币大写的读法

    程序如下: import java.lang.StringBuffer; /** 给定一个浮点数,将其装换成人民币大写的读法 88.5:捌十捌元零伍角 */ public class Num2Rmb ...

  7. golang 环境变量讲解

    以下配置以MAC 下配置为例,但其他环境下大同小异. GOROOT就是go的安装路径在~/.bash_profile中添加下面语句: GOROOT=/usr/local/go export GOROO ...

  8. XDCTF2014 Writeup

    Web50 猜谜语类题目?FLAG在图片中有一些字符的 ASCii值,拼起来就是FLAG. Web100 隐写术.使用工具 StegSolve,把任一颜色的bit0拼起来图片的最开始部分即为  fla ...

  9. Linux系统安装常用开发软件

    vim.jdk.tomcat.mysql 安装vim(命令模式=>编辑模式=>底行模式) [root@localhost ~]# yum install vim*结束后一直确认即可,键入y ...

  10. SPOJ 1825 经过不超过K个黑点的树上最长路径 点分治

    每一次枚举到重心 按子树中的黑点数SORT一下 启发式合并 #include<cstdio> #include<cstring> #include<algorithm&g ...