训练模型需要的数据文件有:

MNIST_data文件夹下的mnist_train、mnist_test、noisy_train、noisy_test。train文件夹下60000个图片,test下10000个图片

noisy_train、noisy_test下的图片加了椒盐噪声与原图序号对应

离线测试需要的数据文件有:

MNIST_data文件夹下的my_model.hdf5、my_test。my_test文件夹下要有一层嵌套文件夹并放测试图片

数据集准备参考:

https://www.cnblogs.com/dzzy/p/10824072.html

训练:

import os
import glob
from PIL import Image
import numpy as np
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator,img_to_array, load_img
from keras.callbacks import TensorBoard , ModelCheckpoint
print("_________________________keras start_____________________________") base_dir = 'MNIST_data' #基准目录
#os._exit(0)
Datapath = os.path.join(base_dir,'mnist_train/*.png') #train目录
x_train = np.zeros((60000, 28, 28))
x_train = np.reshape(x_train, (60000, 28, 28, 1))
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_train[i] = img
i += 1 Datapath = os.path.join(base_dir,'mnist_test/*.png') #test目录
x_test = np.zeros((10000, 28, 28))
x_test = np.reshape(x_test, (10000, 28, 28, 1))
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_test[i] = img
i += 1 print( x_train.shape)
print( x_test.shape) Datapath = os.path.join(base_dir,'noisy_train/*.png') #test目录
x_train_noisy = np.zeros(x_train.shape)
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_train_noisy[i] = img
i += 1 Datapath = os.path.join(base_dir,'noisy_test/*.png') #test目录
x_test_noisy = np.zeros(x_test.shape)
i = 0
for imageFile in glob.glob(Datapath ):
# 打开图像并转化为数字矩阵
img = np.array(Image.open(imageFile))
img = np.reshape(img, (1, 28, 28, 1))
img = img.astype('float32') / 255.
x_test_noisy[i] = img
i += 1 print( x_train_noisy.shape)
print( x_test_noisy.shape) '''
plt.figure(figsize=(20, 4))
plt.subplot(4, 4, 1)
plt.imshow(x_train[0].reshape(28, 28))
plt.subplot(4, 4, 2)
plt.imshow(x_train_noisy[0].reshape(28, 28))
plt.subplot(4, 4, 3)
plt.imshow(x_train[1].reshape(28, 28))
plt.subplot(4, 4, 4)
plt.imshow(x_train_noisy[1].reshape(28, 28))
plt.show()
#os._exit(0)
''' """
搭建模型
"""
input_img = Input(shape=(28, 28, 1)) # adapt this if using `channels_first` image data format x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img) #relu激活函数
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x) # at this point the representation is (7, 7, 32) x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x) autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy') file_path="MNIST_data/weights-improvement-{epoch:02d}-{val_loss:.2f}.hdf5"
tensorboard = TensorBoard(log_dir='/tmp/tb', histogram_freq=0, write_graph=False)
checkpoint = ModelCheckpoint(filepath=file_path,verbose=1,monitor='val_loss', save_weights_only=False,mode='auto' ,save_best_only=True,period=1)
autoencoder.fit(x_train_noisy, x_train,
epochs=100,
batch_size=128,
shuffle=True,
validation_data=(x_test_noisy, x_test),
callbacks=[checkpoint,tensorboard]) #展示结果
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
#noisy data
ax = plt.subplot(3, n, i+1)
plt.imshow(x_test_noisy[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
#predict
ax = plt.subplot(3, n, i+1+n)
decoded_img = autoencoder.predict(x_test_noisy)
plt.imshow(decoded_img[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
#original
ax = plt.subplot(3, n, i+1+2*n)
plt.imshow(x_test[i].reshape(28, 28))
plt.gray()
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
plt.show()

测试:(同  https://www.cnblogs.com/dzzy/p/11387645.html

import os
import numpy as np
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.models import Model,Sequential,load_model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator,img_to_array, load_img
from keras.callbacks import TensorBoard , ModelCheckpoint
print("_________________________keras start_____________________________")
pic_num = 3
base_dir = 'MNIST_data' #基准目录
train_dir = os.path.join(base_dir,'my_test') #train目录
validation_dir="".join(train_dir)
test_datagen = ImageDataGenerator(rescale= 1./255)
validation_generator = test_datagen.flow_from_directory(validation_dir,
target_size = (28,28),
color_mode = "grayscale",
batch_size = pic_num,
class_mode = "categorical")#利用test_datagen.flow_from_directory(图像地址,目标size,批量数目,标签分类情况)
for x_train,batch_labels in validation_generator:
break
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
y_train = x_train # create model
model = load_model('MNIST_data/my_model.hdf5')
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print("Created model and loaded weights from file") # estimate accuracy on whole dataset using loaded weights
y_train=model.predict(x_train)
# 评价训练出的网络
#loss, accuracy = model.evaluate(x_train, y_train)
#print('test loss: ', loss)
#print('test accuracy: ', accuracy) n = pic_num
for i in range(n):
ax = plt.subplot(2, n, i+1)
plt.imshow(x_train[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2, n, i+1+n)
plt.imshow(y_train[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

Keras学习笔记三:一个图像去噪训练并离线测试的例子,基于mnist的更多相关文章

  1. ZooKeeper学习笔记三:使用ZooKeeper实现一个简单的配置中心

    作者:Grey 原文地址:ZooKeeper学习笔记三:使用ZooKeeper实现一个简单的配置中心 前置知识 完成ZooKeeper集群搭建以及熟悉ZooKeeperAPI基本使用 需求 很多程序往 ...

  2. 官网实例详解-目录和实例简介-keras学习笔记四

    官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras   版权声明: ...

  3. Oracle学习笔记三 SQL命令

    SQL简介 SQL 支持下列类别的命令: 1.数据定义语言(DDL) 2.数据操纵语言(DML) 3.事务控制语言(TCL) 4.数据控制语言(DCL)  

  4. ASP.NET MVC Web API 学习笔记---第一个Web API程序

    http://www.cnblogs.com/qingyuan/archive/2012/10/12/2720824.html GetListAll /api/Contact GetListBySex ...

  5. [读书笔记]C#学习笔记三: C#类型详解..

    前言 这次分享的主要内容有五个, 分别是值类型和引用类型, 装箱与拆箱,常量与变量,运算符重载,static字段和static构造函数. 后期的分享会针对于C#2.0 3.0 4.0 等新特性进行. ...

  6. [Firefly引擎][学习笔记三][已完结]所需模块封装

    原地址:http://www.9miao.com/question-15-54671.html 学习笔记一传送门学习笔记二传送门 学习笔记三导读:        笔记三主要就是各个模块的封装了,这里贴 ...

  7. JSP学习笔记(三):简单的Tomcat Web服务器

    注意:每次对Tomcat配置文件进行修改后,必须重启Tomcat 在E盘的DATA文件夹中创建TomcatDemo文件夹,并将Tomcat安装路径下的webapps/ROOT中的WEB-INF文件夹复 ...

  8. java之jvm学习笔记三(Class文件检验器)

    java之jvm学习笔记三(Class文件检验器) 前面的学习我们知道了class文件被类装载器所装载,但是在装载class文件之前或之后,class文件实际上还需要被校验,这就是今天的学习主题,cl ...

  9. VSTO学习笔记(三) 开发Office 2010 64位COM加载项

    原文:VSTO学习笔记(三) 开发Office 2010 64位COM加载项 一.加载项简介 Office提供了多种用于扩展Office应用程序功能的模式,常见的有: 1.Office 自动化程序(A ...

随机推荐

  1. 使用WSAIoctl获取AcceptEx,Connectex,Getacceptexsockaddrs函数指针

    运行WinNT和Win2000的系统上,这些APIs在Microsoft提供的DLL(mswsock.dll)里实现,可以通过链接mswsock.lib或者通过WSAioctl的SIO_GET_EXT ...

  2. Eclipse 快捷键、文档注释、多行注释的快捷键

    关于快捷键 Eclipse 的很多操作都提供了快捷键功能,我们可以通过键盘就能很好的控制 Eclipse 各个功能: 一.多行注释快捷键 1.选中你要加注释的区域,用ctrl+shift+C 或者ct ...

  3. WinForm - 不用自绘实现仿QQ2013

    素材啥的都是一手整理的,绝对的原创.这是13年做的,虽然是个老项目了,可里面涉及的winform技术不会过时,所以就拿出来重温探讨下技术要点. 没使用任何自绘命令,可以说是非常容易理解与学习的. 效果 ...

  4. Delphi MSComm 控件方法

  5. 5、vim编辑器

    1.什么是VIM? 理解为windows下面的文本编辑器,比如记事本,比如word文档 2.为什么要学? 因为在后面我们配置的服务,都需要人为修改配置,以便让程序按照我们修改后的指示运行. 1.修改配 ...

  6. 02.Zabbix⾃定义监控项

    1.zabbix⾃定义监控初试 如何获取系统中想监控对象的值,获取后⼜如何将该值传递给Zabbix-Server 1.1.监控系统中的对象 #(系统监控命令 + awk + 筛选条件 = 监控的状态值 ...

  7. 第九章·Logstash深入-Logstash配合rsyslog收集haproxy日志

    rsyslog介绍及安装配置 在centos 6及之前的版本叫做syslog,centos 7开始叫做rsyslog,根据官方的介绍,rsyslog(2013年版本)可以达到每秒转发百万条日志的级别, ...

  8. Darknet的整体框架,安装,训练与测试

    目录 一.Darknet优势 二.Darknet的结构 三.Darknet安装 四.Darknet的训练 五.Darknet的检测 正文 一.Darknet优势 darknet是一个由纯C编写的深度学 ...

  9. kubernetes资源清单之pod

    什么是pod? Pod是一组一个或多个容器(例如Docker容器),具有共享的存储/网络,以及有关如何运行这些容器的规范. Pod的内容始终位于同一地点,并在同一时间安排,并在共享上下文中运行. Po ...

  10. three.js之让物体动起来方式(二)移动物体

    <!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...