训练模型需要的数据文件有:

MNIST_data文件夹下的mnist_train、mnist_test、noisy_train、noisy_test。train文件夹下60000个图片,test下10000个图片

noisy_train、noisy_test下的图片加了椒盐噪声与原图序号对应

离线测试需要的数据文件有:

MNIST_data文件夹下的my_model.hdf5、my_test。my_test文件夹下要有一层嵌套文件夹并放测试图片

数据集准备参考:

https://www.cnblogs.com/dzzy/p/10824072.html

训练:

import os
import glob
from PIL import Image
import numpy as np
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator,img_to_array, load_img
from keras.callbacks import TensorBoard , ModelCheckpoint
print("_________________________keras start_____________________________") base_dir = 'MNIST_data' #基准目录
#os._exit(0)
Datapath = os.path.join(base_dir,'mnist_train/*.png') #train目录
x_train = np.zeros((60000, 28, 28))
x_train = np.reshape(x_train, (60000, 28, 28, 1))
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_train[i] = img
i += 1 Datapath = os.path.join(base_dir,'mnist_test/*.png') #test目录
x_test = np.zeros((10000, 28, 28))
x_test = np.reshape(x_test, (10000, 28, 28, 1))
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_test[i] = img
i += 1 print( x_train.shape)
print( x_test.shape) Datapath = os.path.join(base_dir,'noisy_train/*.png') #test目录
x_train_noisy = np.zeros(x_train.shape)
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_train_noisy[i] = img
i += 1 Datapath = os.path.join(base_dir,'noisy_test/*.png') #test目录
x_test_noisy = np.zeros(x_test.shape)
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_test_noisy[i] = img
i += 1 print( x_train_noisy.shape)
print( x_test_noisy.shape) '''
plt.figure(figsize=(20, 4))
plt.subplot(4, 4, 1)
plt.imshow(x_train[0].reshape(28, 28))
plt.subplot(4, 4, 2)
plt.imshow(x_train_noisy[0].reshape(28, 28))
plt.subplot(4, 4, 3)
plt.imshow(x_train[1].reshape(28, 28))
plt.subplot(4, 4, 4)
plt.imshow(x_train_noisy[1].reshape(28, 28))
plt.show()
#os._exit(0)
''' """
搭建模型
"""
input_img = Input(shape=(28, 28, 1)) # adapt this if using `channels_first` image data format x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img) #relu激活函数
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x) # at this point the representation is (7, 7, 32) x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x) autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy') file_path="MNIST_data/weights-improvement-{epoch:02d}-{val_loss:.2f}.hdf5"
tensorboard = TensorBoard(log_dir='/tmp/tb', histogram_freq=0, write_graph=False)
checkpoint = ModelCheckpoint(filepath=file_path,verbose=1,monitor='val_loss', save_weights_only=False,mode='auto' ,save_best_only=True,period=1)
autoencoder.fit(x_train_noisy, x_train,
epochs=100,
batch_size=128,
shuffle=True,
validation_data=(x_test_noisy, x_test),
callbacks=[checkpoint,tensorboard]) #展示结果
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
#noisy data
ax = plt.subplot(3, n, i+1)
plt.imshow(x_test_noisy[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
#predict
ax = plt.subplot(3, n, i+1+n)
decoded_img = autoencoder.predict(x_test_noisy)
plt.imshow(decoded_img[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
#original
ax = plt.subplot(3, n, i+1+2*n)
plt.imshow(x_test[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
plt.show()

测试:(同  https://www.cnblogs.com/dzzy/p/11387645.html

import os
import numpy as np
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.models import Model,Sequential,load_model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator,img_to_array, load_img
from keras.callbacks import TensorBoard , ModelCheckpoint
print("_________________________keras start_____________________________")
pic_num = 3
base_dir = 'MNIST_data' #基准目录
train_dir = os.path.join(base_dir,'my_test') #train目录
validation_dir="".join(train_dir)
test_datagen = ImageDataGenerator(rescale= 1./255)
validation_generator = test_datagen.flow_from_directory(validation_dir,
target_size = (28,28),
color_mode = "grayscale",
batch_size = pic_num,
class_mode = "categorical")#利用test_datagen.flow_from_directory(图像地址,目标size,批量数目,标签分类情况)
for x_train,batch_labels in validation_generator:
break
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
y_train = x_train # create model
model = load_model('MNIST_data/my_model.hdf5')
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print("Created model and loaded weights from file") # estimate accuracy on whole dataset using loaded weights
y_train=model.predict(x_train)
# 评价训练出的网络
#loss, accuracy = model.evaluate(x_train, y_train)
#print('test loss: ', loss)
#print('test accuracy: ', accuracy) n = pic_num
for i in range(n):
ax = plt.subplot(2, n, i+1)
plt.imshow(x_train[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2, n, i+1+n)
plt.imshow(y_train[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

Keras学习笔记三:一个图像去噪训练并离线测试的例子,基于mnist的更多相关文章

  1. ZooKeeper学习笔记三:使用ZooKeeper实现一个简单的配置中心

    作者:Grey 原文地址:ZooKeeper学习笔记三:使用ZooKeeper实现一个简单的配置中心 前置知识 完成ZooKeeper集群搭建以及熟悉ZooKeeperAPI基本使用 需求 很多程序往 ...

  2. 官网实例详解-目录和实例简介-keras学习笔记四

    官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras   版权声明: ...

  3. Oracle学习笔记三 SQL命令

    SQL简介 SQL 支持下列类别的命令: 1.数据定义语言(DDL) 2.数据操纵语言(DML) 3.事务控制语言(TCL) 4.数据控制语言(DCL)  

  4. ASP.NET MVC Web API 学习笔记---第一个Web API程序

    http://www.cnblogs.com/qingyuan/archive/2012/10/12/2720824.html GetListAll /api/Contact GetListBySex ...

  5. [读书笔记]C#学习笔记三: C#类型详解..

    前言 这次分享的主要内容有五个, 分别是值类型和引用类型, 装箱与拆箱,常量与变量,运算符重载,static字段和static构造函数. 后期的分享会针对于C#2.0 3.0 4.0 等新特性进行. ...

  6. [Firefly引擎][学习笔记三][已完结]所需模块封装

    原地址:http://www.9miao.com/question-15-54671.html 学习笔记一传送门学习笔记二传送门 学习笔记三导读:        笔记三主要就是各个模块的封装了,这里贴 ...

  7. JSP学习笔记(三):简单的Tomcat Web服务器

    注意:每次对Tomcat配置文件进行修改后,必须重启Tomcat 在E盘的DATA文件夹中创建TomcatDemo文件夹,并将Tomcat安装路径下的webapps/ROOT中的WEB-INF文件夹复 ...

  8. java之jvm学习笔记三(Class文件检验器)

    java之jvm学习笔记三(Class文件检验器) 前面的学习我们知道了class文件被类装载器所装载,但是在装载class文件之前或之后,class文件实际上还需要被校验,这就是今天的学习主题,cl ...

  9. VSTO学习笔记(三) 开发Office 2010 64位COM加载项

    原文:VSTO学习笔记(三) 开发Office 2010 64位COM加载项 一.加载项简介 Office提供了多种用于扩展Office应用程序功能的模式,常见的有: 1.Office 自动化程序(A ...

随机推荐

  1. 【原创】大叔经验分享(62)kudu副本数量

    kudu的副本数量是在表上设置,可以通过命令查看 # sudo -u kudu kudu cluster ksck $master ... Summary by table Name | RF | S ...

  2. BZOJ4241历史研究题解--回滚莫队

    题目链接 https://www.lydsy.com/JudgeOnline/problem.php?id=4241 分析 这题就是求区间权值乘以权值出现次数的最大值,一看莫队法块可搞,但仔细想想,莫 ...

  3. javascript相关的增删改查以及this的理解

    前两天做了一个有关表单增删改查的例子,现在贴出来.主要是想好好说一下this. 下面贴一张我要做的表格效果. 就是实现简单的一个增删改查. 1.点击增加后自动增加一行: 2.点击保存当前行会将属性改成 ...

  4. java9 新特征

    Java 平台级模块系统 java模块化解决的问题:减少Java应用和Java核心运行时环境的大小与复杂性 模块化的 JAR 文件都包含一个额外的模块描述器.在这个模块描述器中, 对其它模块的依赖是通 ...

  5. C++ void*解惑

    最近遇到void *的问题无法解决,发现再也无法逃避了(以前都是采取悄悄绕过原则),于是我决定直面它. 在哪遇到了? 线程创建函数pthread_create()的最后一个参数void *arg,嗯? ...

  6. TCP坚持定时器

    窗口探查(window probe) 当接收方TCP缓冲区没有剩余空间后,在ACK中会通知发送方window=0,此时发送方就暂停发送数据.当接收方TCP缓冲区又有空间后,会再次发送一个ACK,告知其 ...

  7. 获取iframe子页面内容高度给iframe动态设置高度

    <!DOCTYPE html><html> <head> <meta charset="UTF-8" /> <meta nam ...

  8. ngnix 配置说明

    #定义Nginx运行的用户和用户组 user www www; # #nginx进程数,建议设置为等于CPU总核心数. worker_processes ; # #全局错误日志定义类型,[ debug ...

  9. PAT Basic 1044 火星数字 (20 分)

    火星人是以 进制计数的: 地球人的 被火星人称为 tret. 地球人数字 到 的火星文分别为:jan, feb, mar, apr, may, jun, jly, aug, sep, oct, nov ...

  10. php is_numeric函数可绕过产生SQL注入

    老老实实mysql_real_escape_string()防作死......is_numeric的SQL利用条件虽然有点苛刻,但还是少用的好= = 某CTF中亦有实测案例,请戳 http://dro ...