图片样本可视化

原文第四篇中,我们介绍了官方的入门案例MNIST,功能是识别手写的数字0-9。这是一个非常基础的TensorFlow应用,地位相当于通常语言学习的"Hello World!"。

我们先不进入TensorFlow 2.0中的MNIST代码讲解,因为TensorFlow 2.0在Keras的帮助下抽象度比较高,代码非常简单。但这也使得大量的工作被隐藏掉,反而让人难以真正理解来龙去脉。特别是其中所使用的样本数据也已经不同,而这对于学习者,是非常重要的部分。模型可以看论文、在网上找成熟的成果,数据的收集和处理,可不会有人帮忙。

在原文中,我们首先介绍了MNIST的数据结构,并且用一个小程序,把样本中的数组数据转换为JPG图片,来帮助读者理解原始数据的组织方式。

这里我们把小程序也升级一下,直接把图片显示在屏幕上,不再另外保存JPG文件。这样图片看起来更快更直观。

在TensorFlow 1.x中,是使用程序input_data.py来下载和管理MNIST的样本数据集。当前官方仓库的master分支中已经取消了这个代码,为了不去翻仓库,你可以在这里下载,放置到你的工作目录。

在TensorFlow 2.0中,会有keras.datasets类来管理大部分的演示和模型中需要使用的数据集,这个我们后面再讲。

MNIST的样本数据来自Yann LeCun的项目网站。如果网速比较慢的话,可以先用下载工具下载,然后放置到自己设置的数据目录,比如工作目录下的data文件夹,input_data检测到已有数据的话,不会重复下载。

下面是我们升级后显示训练样本集的源码,代码的讲解保留在注释中。如果阅读有疑问的,建议先去原文中看一下样本集数据结构的图示部分:

  1. #!/usr/bin/env python3
  2. # 引入mnist数据预读准备库
  3. # 2.0之后建议直接使用官方的keras.datasets.mnist.load_data
  4. # 此处为了同以前的讲解对比,沿用之前的引用文件
  5. import input_data
  6. # tensorflow 2.0库
  7. import tensorflow as tf
  8. # 引入绘图库
  9. import matplotlib.pyplot as plt
  10. # 这里使用mnist数据预读准备库检查给定路径是已经有样本数据,
  11. # 没有的话去网上下载,并保存在指定目录
  12. # 已经下载了数据的话,将数据读入内存,保存到mnist对象中
  13. mnist = input_data.read_data_sets("data/", one_hot=True)
  14. # 样本集的结构如下:
  15. # mnist.train 训练数据集
  16. # mnist.validation 验证数据集
  17. # mnist.test 测试数据集
  18. # len(mnist.train.images)=55000
  19. # len(mnist.train.images[0])=784
  20. # len(mnist.train.labels[0])=10
  21. def plot_image(i, imgs, labels):
  22. # 将1维的0-1的数据转换为标准的0-255的整数数据,2维28x28的图片
  23. image = tf.floor(256.0 * tf.reshape(imgs[i], [28, 28]))
  24. # 原数据为float,转换为uint8字节数据
  25. image = tf.cast(image, dtype=tf.uint8)
  26. # 标签样本为10个字节的数组,为1的元素下标就是样本的标签值
  27. # 这里使用argmax方法直接转换为0-9的整数
  28. label = tf.argmax(labels[i])
  29. plt.grid(False)
  30. plt.xticks([])
  31. plt.yticks([])
  32. # 绘制样本图
  33. plt.imshow(image)
  34. # 显示标签值
  35. plt.xlabel("{}".format(label))
  36. def show_images(num_rows, num_cols, images, labels):
  37. num_images = num_rows*num_cols
  38. plt.figure('Train Samples', figsize=(2*num_cols, 2*num_rows))
  39. # 循环显示前num_rows*num_cols副样本图片
  40. for i in range(num_images):
  41. plt.subplot(num_rows, num_cols, i+1)
  42. plot_image(i, images, labels)
  43. plt.show()
  44. # 显示前4*6=24副训练集样本
  45. show_images(4, 6, mnist.train.images, mnist.train.labels)

注意这个代码只是用来把样本集可视化。TensorFlow 2.0新特征,在这里只体现了取消Session和Session.run()。目的只是为了延续原来的讲解,让图片直接显示而不是保存为图像文件,以及升级到Python3和TensorFlow 2.0的执行环境。

样本集显示出来效果是这样的:

TensorFlow 2.0中的模型构建

原文第四篇中,我们使用了一个并不实用的线性回归模型来做手写数字识别。这样做可以简化中间层,从而能够使用可视化的手段来讲解机器视觉在数学上的基本原理。因为线性回归模型我们在本系列第一篇中讲过了,这里就跳过,直接说使用神经网络来解决MNIST问题。

神经网络模型的构建在TensorFlow 1.0中是最繁琐的工作。我们曾经为了讲解vgg-19神经网络的使用,首先编写了一个复杂的辅助类,用于从字符串数组的遍历中自动构建复杂的神经网络模型。

而在TensorFlow 2.0中,通过高度抽象的keras,可以非常容易的构建神经网络模型。

为了帮助理解,我们先把TensorFlow 1.0中使用神经网络解决MNIST问题的代码原文粘贴如下:

  1. #!/usr/bin/env python
  2. # -*- coding=UTF-8 -*-
  3. import input_data
  4. mnist = input_data.read_data_sets('data/', one_hot=True)
  5. import tensorflow as tf
  6. sess = tf.InteractiveSession()
  7. #对W/b做初始化有利于防止算法陷入局部最优解,
  8. #文档上讲是为了打破对称性和防止0梯度及神经元节点恒为0等问题,数学原理是类似问题
  9. #这两个初始化单独定义成子程序是因为多层神经网络会有多次调用
  10. def weight_variable(shape):
  11. #填充“权重”矩阵,其中的元素符合截断正态分布
  12. #可以有参数mean表示指定均值及stddev指定标准差
  13. initial = tf.truncated_normal(shape, stddev=0.1)
  14. return tf.Variable(initial)
  15. def bias_variable(shape):
  16. #用0.1常量填充“偏移量”矩阵
  17. initial = tf.constant(0.1, shape=shape)
  18. return tf.Variable(initial)
  19. #定义占位符,相当于tensorFlow的运行参数,
  20. #x是输入的图片矩阵,y_是给定的标注标签,有标注一定是监督学习
  21. x = tf.placeholder("float", shape=[None, 784])
  22. y_ = tf.placeholder("float", shape=[None, 10])
  23. #定义输入层神经网络,有784个节点,1024个输出,
  24. #输出的数量是自己定义的,要跟第二层节点的数量吻合
  25. W1 = weight_variable([784, 1024])
  26. b1 = bias_variable([1024])
  27. #使用relu算法的激活函数,后面的公式跟前一个例子相同
  28. h1 = tf.nn.relu(tf.matmul(x, W1) + b1)
  29. #定义第二层(隐藏层)网络,1024输入,512输出
  30. W2 = weight_variable([1024, 512])
  31. b2 = bias_variable([512])
  32. h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)
  33. #定义第三层(输出层),512输入,10输出,10也是我们希望的分类数量
  34. W3 = weight_variable([512, 10])
  35. b3 = bias_variable([10])
  36. #最后一层的输出同样用softmax分类(也算是激活函数吧)
  37. y3=tf.nn.softmax(tf.matmul(h2, W3) + b3)
  38. #交叉熵代价函数
  39. cross_entropy = -tf.reduce_sum(y_*tf.log(y3))
  40. #这里使用了更加复杂的ADAM优化器来做"梯度最速下降",
  41. #前一个例子中我们使用的是:GradientDescentOptimizer
  42. train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
  43. #计算正确率以评估效果
  44. correct_prediction = tf.equal(tf.argmax(y3,1), tf.argmax(y_,1))
  45. accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
  46. #tf初始化及所有变量初始化
  47. sess.run(tf.global_variables_initializer())
  48. #进行20000步的训练
  49. for i in range(20000):
  50. #每批数据50组
  51. batch = mnist.train.next_batch(50)
  52. #每100步进行一次正确率计算并显示中间结果
  53. if i%100 == 0:
  54. train_accuracy = accuracy.eval(feed_dict={
  55. x:batch[0], y_: batch[1]})
  56. print "step %d, training accuracy %g"%(i, train_accuracy)
  57. #使用数据集进行训练
  58. train_step.run(feed_dict={x: batch[0], y_: batch[1]})
  59. #完成模型训练给出最终的评估结果
  60. print "test accuracy %g"%accuracy.eval(feed_dict={
  61. x: mnist.test.images, y_: mnist.test.labels})

总结一下上面TensorFlow 1.x版本MNIST代码中的工作:

  • 使用了一个三层的神经网络,每一层都使用重复性的代码构建
  • 每一层的代码中,要精心计算输入和输出数据的格式、维度,使得每一层同上、下两层完全吻合
  • 精心设计损失函数(代价函数)和选择回归算法
  • 复杂的训练循环

如果你理解了我总结的这几点,请继续看TensorFlow 2.0的实现:

  1. #!/usr/bin/env python3
  2. # 引入mnist数据预读准备库
  3. # 2.0之后建议直接使用官方的keras.datasets.mnist.load_data
  4. # 此处为了同以前的讲解对比,沿用之前的引用文件
  5. import input_data
  6. # tensorflow库
  7. import tensorflow as tf
  8. # tensorflow 已经内置了keras
  9. from tensorflow import keras
  10. # 引入绘图库
  11. import matplotlib.pyplot as plt
  12. # 这里使用mnist数据预读准备库检查给定路径是已经有样本数据,
  13. # 没有的话去网上下载,并保存在指定目录
  14. # 已经下载了数据的话,将数据读入内存,保存到mnist对象中
  15. mnist = input_data.read_data_sets("data/", one_hot=True)
  16. # 样本集的结构如下:
  17. # mnist.train 训练数据集
  18. # mnist.validation 验证数据集
  19. # mnist.test 测试数据集
  20. # len(mnist.train.images)=55000
  21. # len(mnist.train.images[0])=784
  22. # len(mnist.train.labels[0])=10
  23. def plot_image(i, imgs, labels, predictions):
  24. # 将1维的0-1的数据转换为标准的0-255的整数数据,2维28x28的图片
  25. image = tf.floor(256.0 * tf.reshape(imgs[i], [28, 28]))
  26. # 原数据为float,转换为uint8字节数据
  27. image = tf.cast(image, dtype=tf.uint8)
  28. # 标签样本为10个字节的数组,为1的元素下标就是样本的标签值
  29. # 这里使用argmax方法直接转换为0-9的整数
  30. label = tf.argmax(labels[i])
  31. prediction = tf.argmax(predictions[i])
  32. plt.grid(False)
  33. plt.xticks([])
  34. plt.yticks([])
  35. # 绘制样本图
  36. plt.imshow(image)
  37. # 显示标签值,对比显示预测值和实际标签值
  38. plt.xlabel("predict:{} label:{}".format(prediction, label))
  39. def show_images(num_rows, num_cols, images, labels, predictions):
  40. num_images = num_rows*num_cols
  41. plt.figure('Predict Samples', figsize=(2*num_cols, 2*num_rows))
  42. # 循环显示前num_rows*num_cols副样本图片
  43. for i in range(num_images):
  44. plt.subplot(num_rows, num_cols, i+1)
  45. plot_image(i, images, labels, predictions)
  46. plt.show()
  47. # 原文中已经说明了,当前是10个元素数组表示一个数字,
  48. # 值为1的那一元素的索引就是代表的数字,这是分类算法决定的
  49. # 下面是直接转换为0-9的正整数,用作训练的标签
  50. train_labels = tf.argmax(mnist.train.labels, 1)
  51. # 定义神经网络模型
  52. model = keras.Sequential([
  53. # 输入层为28x28共784个元素的数组,节点1024个
  54. keras.layers.Dense(1024, activation='relu', input_shape=(784,)),
  55. keras.layers.Dense(512, activation='relu'),
  56. keras.layers.Dense(10, activation='softmax')
  57. ])
  58. # 编译模型
  59. model.compile(optimizer='adam',
  60. loss='sparse_categorical_crossentropy',
  61. metrics=['accuracy'])
  62. # 使用训练集数据训练模型
  63. model.fit(mnist.train.images, train_labels, epochs=3)
  64. # 测试集的标签同样转成0-9数字
  65. test_labels = tf.argmax(mnist.test.labels, 1)
  66. # 使用测试集样本验证识别准确率
  67. test_loss, test_acc = model.evaluate(mnist.test.images, test_labels)
  68. print('\nTest accuracy:', test_acc)
  69. # 完整预测测试集样本
  70. predictions = model.predict(mnist.test.images)
  71. # 图示结果的前4*6个样本
  72. show_images(4, 6, mnist.test.images, mnist.test.labels, predictions)

代码讲解

通常我都是直接在注释中对程序做仔细的讲解,这次例外一下,因为我们需要从大局观上看清楚代码的结构。

这几行代码是定义神经网络模型:

  1. # 定义神经网络模型
  2. model = keras.Sequential([
  3. # 输入层为28x28共784个元素的数组,节点1024个
  4. keras.layers.Dense(1024, activation='relu', input_shape=(784,)),
  5. keras.layers.Dense(512, activation='relu'),
  6. keras.layers.Dense(10, activation='softmax')
  7. ])

每一行实际就代表一层神经网络的节点。在第一行中特别指明了输入数据的形式,即可以有未知数量的样本,每一个样本784个字节(28x28)。实际上这个输入样本可以不指定形状,在没有指定的情况下,Keras会自动识别训练数据集的形状,并自动将模型输入匹配到训练集形状。只是这种习惯并不一定好,除了效率问题,当样本集出错的时候,模型的定义也无法帮助开发者提前发现问题。所以建议产品化的模型,应当在模型中指定输入数据类型。

除了第一层之外,之后的每一层都无需指定输入样本形状。Keras会自动匹配相邻两个层的数据。这节省了开发人员大量的手工计算也不易出错。

最后,激活函数的选择成为一个参数。整体代码看上去简洁的令人惊讶。

接着在编译模型的代码中,直接指定Keras中预定义的“sparse_categorical_crossentropy”损失函数和“adam”优化算法。一个函数配合几个参数选择就完成了这部分工作:

  1. # 编译模型
  2. model.compile(optimizer='adam',
  3. loss='sparse_categorical_crossentropy',
  4. metrics=['accuracy'])

对原本复杂的训练循环部分,TensorFlow 2.0优化的最为彻底,只有一行代码:

  1. # 使用训练集数据训练模型
  2. model.fit(mnist.train.images, train_labels, epochs=3)

使用测试集数据对模型进行评估同样只需要一行代码,这里就不摘出来了,在上面完整代码中能看到。

可以想象,TensorFlow 2.0正式发布后,模型搭建、训练、评估的工作量大幅减少,会催生很多由实验性模型创新而出现的新算法。机器学习领域会再次涌现普及化浪潮。

这一版代码中,我们还细微修改了样本可视化部分的程序,将原来显示训练集样本,改为显示测试集样本。主要是增加了一个图片识别结果的参数。将图片的识别结果同数据集的标注一同显示在图片的下面作为对比。

程序运行的时候,控制台输出如下:

  1. $ python3 mnist-show-predict-pic-v1.py
  2. Extracting data/train-images-idx3-ubyte.gz
  3. Extracting data/train-labels-idx1-ubyte.gz
  4. Extracting data/t10k-images-idx3-ubyte.gz
  5. Extracting data/t10k-labels-idx1-ubyte.gz
  6. Epoch 1/3
  7. 55000/55000 [==============================] - 17s 307us/sample - loss: 0.1869 - accuracy: 0.9420
  8. Epoch 2/3
  9. 55000/55000 [==============================] - 17s 304us/sample - loss: 0.0816 - accuracy: 0.9740
  10. Epoch 3/3
  11. 55000/55000 [==============================] - 16s 298us/sample - loss: 0.0557 - accuracy: 0.9821
  12. 10000/10000 [==============================] - 1s 98us/sample - loss: 0.0890 - accuracy: 0.9743
  13. Test accuracy: 0.9743

最终的结果表示,模型通过3次的训练迭代之后。使用测试集数据进行验证,手写体数字识别正确率为97.43%。

程序最终会显示测试集前24个图片及预测结果和标注信息的对比:

(待续...)

TensorFlow从1到2(二)续讲从锅炉工到AI专家的更多相关文章

  1. TensorFlow从1到2(一)续讲从锅炉工到AI专家

    引言 原来引用过一个段子,这里还要再引用一次.是关于苹果的.大意是,苹果发布了新的开发语言Swift,有非常多优秀的特征,于是很多时髦的程序员入坑学习.不料,经过一段头脑体操一般的勤学苦练,发现使用S ...

  2. Tensorflow深度学习之十二:基础图像处理之二

    Tensorflow深度学习之十二:基础图像处理之二 from:https://blog.csdn.net/davincil/article/details/76598474   首先放出原始图像: ...

  3. 二十一世纪计算 | John Hopcroft:AI革命

    编者按:信息革命的浪潮浩浩汤汤,越来越多的人将注意力转向人工智能,想探索它对人类生产生活所产生的可能影响.人工智能的下一步发展将主要来自深度学习,在这个领域中,更多令人兴奋的话题在等待我们探讨:神经网 ...

  4. 人工智能(AI)库TensorFlow 踩坑日记之二

    上次 踩坑日志之一 遗留的问题终于解决了,所以作者(也就是我)终于有脸出来写第二篇了. 首先还是贴上 卷积算法的示例代码地址 :https://github.com/tensorflow/models ...

  5. Android学习之基础知识十二 — 第一讲:网络技术的使用

    这一节主要讲如何在手机端使用HTTP协议和服务器端进行网络交互,并对服务器返回的数据进行解析,这也是Android中最常用的网络技术. 一.WebView的用法 有时候我们可能会碰到比较特殊的需求,比 ...

  6. Tensorflow简单实践系列(二):张量

    在上一节中,我们安装 TensorFlow 并运行了最简单的应用,这节我们熟悉 TensorFlow 中的张量. 张量是 TensorFlow 的核心数据类型.数学里面也有张量的概念,但是 Tenso ...

  7. 使用Tensorflow搭建回归预测模型之二:数据准备与预处理

    前言: 在前一篇中,已经搭建好了Tensorflow环境,本文将介绍如何准备数据与预处理数据. 正文: 在机器学习中,数据是非常关键的一个环节,在模型训练前对数据进行准备也预处理是非常必要的. 一.数 ...

  8. Google Tensorflow 源码编译(二):Bazel<v0.1.0>

    这几天终于把tensorflow安装上了,中间遇到过不少的问题,这里记录下来.供大家想源码安装的参考. 安装环境:POWER8处理器,Docker容器Ubuntu14.04镜像. Build Baze ...

  9. TF Boys (TensorFlow Boys ) 养成记(二)

    TensorFlow 的 How-Tos,讲解了这么几点: 1. 变量:创建,初始化,保存,加载,共享: 2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Vis ...

随机推荐

  1. unity做游戏常用功能实现(一)多方向同时输入也能让物体正常移动

    -------小基原创,转载请给我一个面子 网上有很多讲输入控制如何移动,但是多数都是讲单一按下,对于同时按下2个或2个以上按键并没有说明怎么解决,这里小基研究了一下方便大家 (如果你直接写input ...

  2. SEO优化-robots.txt解读

    一.什么是robots.txt robots.txt 文件由一条或多条规则组成.每条规则可禁止(或允许)特定抓取工具抓取相应网站中的指定文件路径. 通俗一点的说法就是:告诉爬虫,我这个网站,你哪些能看 ...

  3. 浏览器渲染原理笔记 --《How Browser Work》读后总结

    综述 之前使用ExtJS时遇到一个问题:为什么依次设置多个组件的可见性界面会卡顿?在了解HTML的dom操作相关内容的时候也好奇这个东西到底是怎么回事,然后尤其搞不懂CSS和Html分管样式和网页结构 ...

  4. Python入门、练手、视频资源汇总,拿走别客气!

    摘要:为方便朋友,重新整理汇总,内容包括长期必备.入门教程.练手项目.学习视频. 一.长期必备. 1. StackOverflow,是疑难解答.bug排除必备网站,任何编程问题请第一时间到此网站查找. ...

  5. java基础- Collection和map

    使用构造方法时,需要保留一个无参的构造方法 静态方法可以直接通过类名来访问,而不用创建对象. -- Java代码的执行顺序: 静态变量初始化→静态代码块→初始化静态方法→初始化实例变量→代码块→构造方 ...

  6. python访问mysql

    1,下载mysql-connector-python-2.0.4  pythoin访问mysql需要有客户端,这个就是连接mysql的库 解压后如下图: 双击lib 以windows为例 把mysql ...

  7. sqlilabs 5

    第一个1不断返回true,2可以进行更改?id=-1' union select 1,2,3 and '1?id=-1' union select 1,2,3 and 1='1 ?id=-1' uni ...

  8. 什么是web service ?

    一.序言 大家或多或少都听过WebService(Web服务),有一段时间很多计算机期刊.书籍和网站都大肆的提及和宣传WebService技术,其中不乏很多吹嘘和做广告的成分.但是不得不承认的是Web ...

  9. java基础-学java util类库总结

    JAVA基础 Util包介绍 学Java基础的工具类库java.util包.在这个包中,Java提供了一些实用的方法和数据结构.本章介绍Java的实用工具类库java.util包.在这个包中,Java ...

  10. Kali Linux桥接模式配置DNS服务器

    操作环境: 虚拟机操作系统: Kali Linux 2017.2 虚拟化软件: VMWare Workstation 14 pro 操作前的准备: 在设置里将Kali的上网模式设置成"桥接模 ...