Tensorflow2.0-mnist手写数字识别示例

     

      读书不觉春已深,一寸光阴一寸金。

简介:通过CNN 卷积神经网络训练后识别出手写图片,测试图片mnist数据集中的0、1、2、4。

                 

一、mnist数据集准备

虽然可以通过代码自动下载数据集,但是mnist 数据集国内下载不稳定,会出现【Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz】的情况,代码从定义目录data_set_tf3 中未获取到mnist 数据集就会自动下载,但下载时间比较久,还是提前准备好。

Downloading mnist data from https

mnist数据集下载地址

mnist数据集官网如上,下载下面四个东西就可以了,图中标红的两个images和lables。

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)

Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)

Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)

Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

MNIST 数据集来自美国国家标准与技术研究所,  训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 的工作人员;测试集(test set) 也是同样比例的手写数字数据;可以新建一个文件夹 – mnist, 将数据集下载到 mnist 解压即可。

mnist数据集整合

三、图片训练

train.py 训练代码如下:

 1 import os
2 import tensorflow as tf
3 from tensorflow.keras import datasets, layers, models
4
5 '''
6 python 3.7、3.9
7 tensorflow 2.0.0b0
8 '''
9
10 # 模型定义的前半部分主要使用Keras.layers 提供的Conv2D(卷积)与MaxPooling2D(池化)函数。
11 # CNN的输入是维度为(image_height, image_width, color_channels)的张量,
12 # mnist数据集是黑白的,因此只有一个color_channels 颜色通道;一般的彩色图片有3个(R, G, B),
13 # 也有4个通道的(R, G, B, A),A代表透明度;
14 # 对于mnist数据集,输入的张量维度为(28, 28, 1),通过参数input_shapa 传给网络的第一层
15 # CNN模型处理:
16 class CNN(object):
17 def __init__(self):
18 model = models.Sequential()
19 # 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小
20 model.add(layers.Conv2D(
21 32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
22 model.add(layers.MaxPooling2D((2, 2)))
23 # 第2层卷积,卷积核大小为3*3,64个
24 model.add(layers.Conv2D(64, (3, 3), activation='relu')) # 使用神经网络中激活函数ReLu
25 model.add(layers.MaxPooling2D((2, 2)))
26 # 第3层卷积,卷积核大小为3*3,64个
27 model.add(layers.Conv2D(64, (3, 3), activation='relu'))
28
29 model.add(layers.Flatten())
30 model.add(layers.Dense(64, activation='relu'))
31 model.add(layers.Dense(10, activation='softmax'))
32 # Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小
33 # dense :全连接层相当于添加一个层
34 # softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类!
35 model.summary() # 输出模型各层的参数状况
36
37 self.model = model
38
39
40 # mnist数据集预处理
41 class DataSource(object):
42 def __init__(self):
43 # mnist数据集存储的位置,如果不存在将自动下载
44 data_path = os.path.abspath(os.path.dirname(
45 __file__)) + '/../data_set_tf2/mnist.npz'
46 (train_images, train_labels), (test_images,
47 test_labels) = datasets.mnist.load_data(path=data_path)
48 # 6万张训练图片,1万张测试图片
49 train_images = train_images.reshape((60000, 28, 28, 1))
50 test_images = test_images.reshape((10000, 28, 28, 1))
51 # 像素值映射到 0 - 1 之间
52 train_images, test_images = train_images / 255.0, test_images / 255.0
53
54 self.train_images, self.train_labels = train_images, train_labels
55 self.test_images, self.test_labels = test_images, test_labels
56
57
58 # 开始训练并保存训练结果
59 class Train:
60 def __init__(self):
61 self.cnn = CNN()
62 self.data = DataSource()
63
64 def train(self):
65 check_path = './ckpt/cp-{epoch:04d}.ckpt'
66 # period 每隔5epoch保存一次
67 save_model_cb = tf.keras.callbacks.ModelCheckpoint(
68 check_path, save_weights_only=True, verbose=1, period=5)
69
70 self.cnn.model.compile(optimizer='adam',
71 loss='sparse_categorical_crossentropy',
72 metrics=['accuracy'])
73 self.cnn.model.fit(self.data.train_images, self.data.train_labels,
74 epochs=5, callbacks=[save_model_cb])
75
76 test_loss, test_acc = self.cnn.model.evaluate(
77 self.data.test_images, self.data.test_labels)
78 print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
79
80
81 if __name__ == "__main__":
82 app = Train()
83 app.train()

~拍一拍小轮胎

mnist手写数字识别训练了四分钟左右,准确率高达0.9902,下面的视频只截取了训练的前十秒。

 mnist手写数字识别训练视频

model.summary()打印定义的模型结构

CNN定义的模型结构

 1 Model: "sequential"
2 _________________________________________________________________
3 Layer (type) Output Shape Param #
4 =================================================================
5 conv2d (Conv2D) (None, 26, 26, 32) 320
6 _________________________________________________________________
7 max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0
8 _________________________________________________________________
9 conv2d_1 (Conv2D) (None, 11, 11, 64) 18496
10 _________________________________________________________________
11 max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0
12 _________________________________________________________________
13 conv2d_2 (Conv2D) (None, 3, 3, 64) 36928
14 _________________________________________________________________
15 flatten (Flatten) (None, 576) 0
16 _________________________________________________________________
17 dense (Dense) (None, 64) 36928
18 _________________________________________________________________
19 dense_1 (Dense) (None, 10) 650
20 =================================================================
21 Total params: 93,322
22 Trainable params: 93,322
23 Non-trainable params: 0
24 _________________________________________________________________

我们可以看到,每一个Conv2D 和MaxPooling2D 层的输出都是一个三维的张量(height, width, channels),height 和width 会逐渐地变小;输出的channel 的个数,是由第一个参数(例如,32或64)控制的,随着height 和width 的变小,channel可以变大(从算力的角度)。

模型的后半部分,是定义张量的输出。layers.Flatten 会将三维的张量转为一维的向量,展开前张量的维度是(3, 3, 64) ,转为一维(576)【3*3*64】的向量后,紧接着使用layers.Dense 层,构造了2层全连接层,逐步地将一维向量的位数从576变为64,再变为10。

后半部分相当于是构建了一个隐藏层为64,输入层为576,输出层为10的普通的神经网络。最后一层的激活函数是softmax,10位恰好可以表达0-9十个数字。最大值的下标即可代表对应的数字,使用numpy 的argmax() 方法获取最大值下标,很容易计算得到预测值。

train.py运行结果

可以看到,在第一轮训练后,识别准确率达到了0.9536,五轮训练之后,使用测试集验证,准确率达到了0.9902。在第五轮时,模型参数成功保存在了./ckpt/cp-0005.ckpt,而且此时准确率为更高的0.9940,所以也并不是训练时间次数越久越好,过犹不及。可以加载保存的模型参数,恢复整个卷积神经网络,进行真实图片的预测。

保存训练模型参数

四、图片预测

predict.py代码如下:

 1 import tensorflow as tf
2 from PIL import Image
3 import numpy as np
4
5 from mnist.v4_cnn.train import CNN
6
7 '''
8 python 3.7 3.9
9 tensorflow 2.0.0b0
10 pillow(PIL) 4.3.0
11 '''
12
13
14 class Predict(object):
15 def __init__(self):
16 latest = tf.train.latest_checkpoint('./ckpt')
17 self.cnn = CNN()
18 # 恢复网络权重
19 self.cnn.model.load_weights(latest)
20
21 def predict(self, image_path):
22 # 以黑白方式读取图片
23 img = Image.open(image_path).convert('L')
24 img = np.reshape(img, (28, 28, 1)) / 255.
25 x = np.array([1 - img])
26
27 # API refer: https://keras.io/models/model/
28 y = self.cnn.model.predict(x)
29
30 # 因为x只传入了一张图片,取y[0]即可
31 # np.argmax()取得最大值的下标,即代表的数字
32 print(image_path)
33 print(y[0])
34 print(' -> Predict picture number is: ', np.argmax(y[0]))
35
36
37 if __name__ == "__main__":
38 app = Predict()
39 app.predict('../test_images/0.png')
40 app.predict('../test_images/1.png')
41 app.predict('../test_images/4.png')
42 app.predict('../test_images/2.png')

预测结果

 预测结果:

 1 ../test_images/0.png
2 [9.9999774e-01 2.6819215e-08 1.2541744e-07 8.7437911e-08 1.0661940e-09
3 3.3693670e-08 4.6488995e-07 3.5915035e-09 9.8040758e-08 1.4385278e-06]
4 -> Predict picture number is: 0
5 ../test_images/1.png
6 [7.75440956e-09 9.99991298e-01 1.41642090e-07 1.09819875e-10
7 6.76554646e-06 7.63710162e-09 2.37024622e-08 1.58189516e-06
8 2.49125264e-07 4.92376007e-09]
9 -> Predict picture number is: 1
10 ../test_images/4.png
11 [7.03467840e-10 8.20740708e-04 1.11648405e-04 3.93262711e-09
12 9.99048650e-01 1.08713095e-07 4.24647197e-08 1.85665340e-05
13 5.03181887e-08 1.86591734e-07]
14 -> Predict picture number is: 4
15 ../test_images/2.png
16 [1.5828672e-08 1.9245699e-07 9.9999440e-01 5.3448480e-06 1.7397912e-10
17 8.6148493e-13 2.5441890e-10 5.3953073e-08 3.5735226e-08 8.9734775e-11]
18 -> Predict picture number is: 2

如上,经CNN训练后通过模型参数准确预测出了0、1、2、4四张手写图片的真实值。

                

    

 读书不觉春已深

                            一寸光阴一寸金

Tensorflow2.0-mnist手写数字识别示例的更多相关文章

  1. [TensorFow2.0]-MNIST手写数字识别

    本人人工智能初学者,现在在学习TensorFlow2.0,对一些学习内容做一下笔记.笔记中,有些内容理解可能较为肤浅.有偏差等,各位在阅读时如有发现问题,请评论或者邮箱(右侧边栏有邮箱地址)提醒. 若 ...

  2. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  3. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  4. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  5. 第三节,CNN案例-mnist手写数字识别

    卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...

  6. mnist 手写数字识别

    mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...

  7. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  8. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  9. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

随机推荐

  1. 自动化运维工具之Puppet常用资源(二)

    前文我们了解了部分puppet的资源的使用,以及资源和资源的依赖关系的定义,回顾请参考https://www.cnblogs.com/qiuhom-1874/p/14071459.html:今天我们继 ...

  2. 老猿学5G:融合计费场景的离线计费会话的Nchf_OfflineOnlyCharging_Release释放操作

    ☞ ░ 前往老猿Python博文目录 ░ 一.Nchf_OfflineOnlyCharging_Release消息交互流程 Nchf_OfflineOnlyCharging_Release是CHF提供 ...

  3. 搭建ARL资产安全灯塔

    老年人了,只能靠安装部署项目混混日子这样~ 简介: 斗象TCC团队正式发布「ARL资产安全灯塔」开源版,该项目现已上线开源社区GitHub.ARL旨在快速侦察与目标关联的互联网资产,构建基础资产信息库 ...

  4. 1、tensorflow 框架理解

    2020/10/31 参考:https://blog.csdn.net/mzpmzk/article/details/78636127 1. 两大步骤:定义图define the graph, 进行计 ...

  5. Java基础学习之流程控制语句(5)

    目录 1.顺序结构 2.选择结构 2.1.if else结构 2.2.switch case结构 3.循环结构 3.1.while结构 3.2.do while结构 3.3.for结构 3.3.1.普 ...

  6. Web前端-按钮点击效果(水波纹)

    这种效果可以由元素内嵌套canves实现,也可以由css3实现. Canves实现 网上摘了一份canves实现的代码,略微去掉了些重复定义的样式并且给出js注释,代码如下 第一种方法: html骨架 ...

  7. javascript常用继承方式.

      //原型链继承 function Parent() { this.name = 'per'; } function Child() { this.age = 20; } Child.prototy ...

  8. SNOI2020 部分题解

    D1T1 画图可以发现,多了一条边过后的图是串并联图.(暂时不确定) 然后我们考虑把问题变成,若生成树包含一条边\(e\),则使生成树权值乘上\(a_e\),否则乘上\(b_e\),求最终的生成树权值 ...

  9. Java集合源码分析(一)——集合框架

    集合框架 集合框架如图所示 Java集合是Java提供的工具包,主要包括常用的数据结构,包括:集合.链表.队列.栈.数组.映射等. 集合的工具包位置是java.util.* 集合主要可以分为五类: L ...

  10. Docker安装RabbitMQ与Kafka

    RabbitMq安装(dokcer) 下载镜像 docker pull rabbitmq 创建并启动容器 docker run -d --name rabbitmq -p 5672:5672 -p 1 ...