TensorFlow MNIST初级学习
MNIST
MNIST 是一个入门级计算机视觉数据集,包含了很多手写数字图片,如图所示:
数据集中包含了图片和对应的标注,在 TensorFlow 中提供了这个数据集,我们可以用如下方法进行导入:
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data/', one_hot=True) print(mnist)
输出结果如下:
Extracting MNIST_data/train-images-idx3-ubyte.gz Extracting MNIST_data/train-labels-idx1-ubyte.gz Extracting MNIST_data/t10k-images-idx3-ubyte.gz Extracting MNIST_data/t10k-labels-idx1-ubyte.gz Datasets(train=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x101707ef0>, validation=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x1016ae4a8>, test=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x1016f9358>)
在这里程序会首先下载 MNIST 数据集,然后解压并保存到刚刚制定好的 MNIST_data 文件夹中,然后输出数据集对象。
数据集中包含了 55000 行的训练数据集(mnist.train)、5000 行验证集(mnist.validation)和 10000 行的测试数据集(mnist.test),文件如下所示:
正如前面提到的一样,每一个 MNIST 数据单元有两部分组成:一张包含手写数字的图片和一个对应的标签。我们把这些图片设为 xs,把这些标签设为 ys。训练数据集和测试数据集都包含 xs 和 ys,比如训练数据集的图片是 mnist.train.images ,训练数据集的标签是 mnist.train.labels,每张图片是 28 x 28 像素,即 784 个像素点,我们可以把它展开形成一个向量,即长度为 784 的向量。
所以训练集我们可以转化为 [55000, 784] 的向量,第一维就是训练集中包含的图片个数,第二维是图片的像素点表示的向量。
Softmax
Softmax 可以看成是一个激励(activation)函数或者链接(link)函数,把我们定义的线性函数的输出转换成我们想要的格式,也就是关于 10 个数字类的概率分布。因此,给定一张图片,它对于每一个数字的吻合度可以被 Softmax 函数转换成为一个概率值。Softmax 函数可以定义为:
展开等式右边的子式,可以得到:
比如判断一张图片中的动物是什么,可能的结果有三种,猫、狗、鸡,假如我们可以经过计算得出它们分别的得分为 3.2、5.1、-1.7,Softmax 的过程首先会对各个值进行次幂计算,分别为 24.5、164.0、0.18,然后计算各个次幂结果占总次幂结果的比重,这样就可以得到 0.13、0.87、0.00 这三个数值,所以这样我们就可以实现差别的放缩,即好的更好、差的更差。
如果要进一步求损失值可以进一步求对数然后取负值,这样 Softmax 后的值如果值越接近 1,那么得到的值越小,即损失越小,如果越远离 1,那么得到的值越大。
实现回归模型
首先导入 TensorFlow,命令如下:
import tensorflow as tf
接下来我们指定一个输入,在这里输入即为样本数据,如果是训练集那么则是 55000 x 784 的矩阵,如果是验证集则为 5000 x 784 的矩阵,如果是测试集则是 10000 x 784 的矩阵,所以它的行数是不确定的,但是列数是确定的。
所以可以先声明一个 placeholder 对象:
x = tf.placeholder(tf.float32, [None, ])
这里第一个参数指定了矩阵中每个数据的类型,第二个参数指定了数据的维度。
接下来我们需要构建第一层网络,表达式如下:
这里实际上是对输入的 x 乘以 w 权重,然后加上一个偏置项作为输出,而这两个变量实际是在训练的过程中动态调优的,所以我们需要指定它们的类型为 Variable,代码如下:
w = tf.Variable(tf.zeros([, ])) b = tf.Variable(tf.zeros([]))
接下来需要实现的就是上图所述的公式了,我们再进一步调用 Softmax 进行计算,得到 y:
y = tf.nn.softmax(tf.matmul(x, w) + b)
通过上面几行代码我们就已经把模型构建完毕了,结构非常简单。
损失函数
为了训练我们的模型,我们首先需要定义一个指标来评估这个模型是好的。其实,在机器学习,我们通常定义指标来表示一个模型是坏的,这个指标称为成本(cost)或损失(loss),然后尽量最小化这个指标。但是这两种方式是相同的。
一个非常常见的,非常漂亮的成本函数是“交叉熵”(cross-entropy)。交叉熵产生于信息论里面的信息压缩编码技术,但是它后来演变成为从博弈论到机器学习等其他领域里的重要技术手段。它的定义如下:
y 是我们预测的概率分布, y_label 是实际的分布,比较粗糙的理解是,交叉熵是用来衡量我们的预测用于描述真相的低效性。
我们可以首先定义 y_label,它的表达式是:
y_label = tf.placeholder(tf.float32, [None, ])
接下来我们需要计算它们的交叉熵,代码如下:
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label * tf.log(y), reduction_indices=[]))
首先用 reduce_sum() 方法针对每一个维度进行求和,reduction_indices 是指定沿哪些维度进行求和。
然后调用 reduce_mean() 则求平均值,将一个向量中的所有元素求算平均值。
这样我们最后只需要优化这个交叉熵就好了。
所以这样我们再定义一个优化方法:
train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
这里使用了 GradientDescentOptimizer,在这里,我们要求 TensorFlow 用梯度下降算法(gradient descent algorithm)以 0.5 的学习速率最小化交叉熵。梯度下降算法(gradient descent algorithm)是一个简单的学习过程,TensorFlow 只需将每个变量一点点地往使成本不断降低的方向移动即可。
运行模型
定义好了以上内容之后,相当于我们已经构建好了一个计算图,即设置好了模型,我们把它放到 Session 里面运行即可:
with tf.Session() as sess: sess.run(tf.global_variables_initializer()) ): batch_x, batch_y = mnist.train.next_batch(batch_size) sess.run(train, feed_dict={x: batch_x, y_label: batch_y})
该循环的每个步骤中,我们都会随机抓取训练数据中的 batch_size 个批处理数据点,然后我们用这些数据点作为参数替换之前的占位符来运行 train。
这里需要一些变量的定义:
batch_size = total_steps =
测试模型
那么我们的模型性能如何呢?
首先让我们找出那些预测正确的标签。tf.argmax() 是一个非常有用的函数,它能给出某个 Tensor 对象在某一维上的其数据最大值所在的索引值。由于标签向量是由 0,1 组成,因此最大值 1 所在的索引位置就是类别标签,比如 tf.argmax(y, 1) 返回的是模型对于任一输入 x 预测到的标签值,而 tf.argmax(y_label, 1) 代表正确的标签,我们可以用 tf.equal() 方法来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)。
correct_prediction = tf.equal(tf.argmax(y, axis=), tf.argmax(y_label, axis=))
这行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。例如,[True, False, True, True] 会变成 [1, 0, 1, 1] ,取平均值后得到 0.75。
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
最后,我们计算所学习到的模型在测试数据集上面的正确率,定义如下:
steps_per_test = : print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))
这个最终结果值应该大约是92%。
这样我们就通过完成了训练和测试阶段,实现了一个基本的训练模型,后面我们会继续优化模型来达到更好的效果。
运行结果如下:
0.453 0.8915 0.9026 0.9081 0.9109 0.9108 0.9175 0.9137 0.9158 0.9176 0.9167 0.9186 0.9206 0.9161 0.9218 0.9179 0.916 0.9196 0.9222 0.921 0.9223 0.9214 0.9191 0.9228 0.9228 0.9218 0.9197 0.9225 0.9238 0.9219 0.9224 0.9184 0.9253 0.9216 0.9218 0.9212 0.9225 0.9224 0.9225 0.9226 0.9201 0.9138 0.9184 0.9222 0.92 0.924 0.9234 0.9219 0.923 0.9254 0.9218
结语
本节通过一个 MNIST 数据集来简单体验了一下真实数据的训练和预测过程,但是准确率还不够高,后面我们会学习用卷积的方式来进行模型训练,准确率会更高。
TensorFlow MNIST初级学习的更多相关文章
- 学习笔记TF056:TensorFlow MNIST,数据集、分类、可视化
MNIST(Mixed National Institute of Standards and Technology)http://yann.lecun.com/exdb/mnist/ ,入门级计算机 ...
- TensorFlow和深度学习入门教程(TensorFlow and deep learning without a PhD)【转】
本文转载自:https://blog.csdn.net/xummgg/article/details/69214366 前言 上月导师在组会上交我们用tensorflow写深度学习和卷积神经网络,并把 ...
- TensorFlow和深度学习新手教程(TensorFlow and deep learning without a PhD)
前言 上月导师在组会上交我们用tensorflow写深度学习和卷积神经网络.并把其PPT的參考学习资料给了我们, 这是codelabs上的教程:<TensorFlow and deep lear ...
- Mac tensorflow mnist实例
Mac tensorflow mnist实例 前期主要需要安装好tensorflow的环境,Mac 如果只涉及到CPU的版本,推荐使用pip3,傻瓜式安装,一行命令!代码使用python3. 在此附上 ...
- TensorFlow MNIST(手写识别 softmax)实例运行
TensorFlow MNIST(手写识别 softmax)实例运行 首先要有编译环境,并且已经正确的编译安装,关于环境配置参考:http://www.cnblogs.com/dyufei/p/802 ...
- 基于TensorFlow的深度学习系列教程 2——常量Constant
前面介绍过了Tensorflow的基本概念,比如如何使用tensorboard查看计算图.本篇则着重介绍和整理下Constant相关的内容. 基于TensorFlow的深度学习系列教程 1--Hell ...
- TensorFlow MNIST 问题解决
TensorFlow MNIST 问题解决 一.数据集下载错误 错误:IOError: [Errno socket error] [Errno 101] Network is unreachable ...
- TensorFlow (RNN)深度学习 双向LSTM(BiLSTM)+CRF 实现 sequence labeling 序列标注问题 源码下载
http://blog.csdn.net/scotfield_msn/article/details/60339415 在TensorFlow (RNN)深度学习下 双向LSTM(BiLSTM)+CR ...
- TensorFlow机器学习框架-学习笔记-001
# TensorFlow机器学习框架-学习笔记-001 ### 测试TensorFlow环境是否安装完成-----------------------------```import tensorflo ...
随机推荐
- console引起的eclipse 僵死/假死 问题排查及解决[转]
原文链接:http://www.iteye.com/topic/1133941 症状: 使用Eclipse win 64位版本,indigo及kepler都重现了,使用tomcat 6.0.39,jd ...
- NOIP 2017 day -1 杂记
我几乎要崩溃了. 写任何板子都是第一遍一定写不对,后来发现是傻逼性错误. 好奇怪的,这些东西明明我都会,为什么现在我都忘了? 很烦.现在心里特别乱,写什么都写不下去. 可能我是真的无法放心这次的比赛. ...
- bzoj 4033: [HAOI2015]树上染色 [树形DP]
4033: [HAOI2015]树上染色 我写的可是\(O(n^2)\)的树形背包! 注意j倒着枚举,而k要正着枚举,因为k可能从0开始,会使用自己更新一次 #include <iostream ...
- Docker小记 — Docker Engine
前言 用了Docker方才觉得生产环境终于有了他该有的样子,就像集装箱普及之后大型货轮的价值才逐渐体现出来,Docker详细说明可查阅"官方文档".本篇为Docker Engine ...
- IntelliJ IDEA下Maven SpringMVC+Mybatis入门搭建例子
很久之前写了一篇SSH搭建例子,由于工作原因已经转到SpringMVC+Mybatis,就以之前SSH实现简单登陆的例子,总结看看SpringMVC+Mybatis怎么实现. Spring一开始是轻量 ...
- asp.net core 使用html文件
在asp.net core 项目中,使用html文件一般通过使用中间件来提供服务: 打开 NuGet程序管理控制台 输入install-package Microsoft.aspnetcore.sta ...
- SSE图像算法优化系列十六:经典USM锐化中的分支判断语句SSE实现的几种方法尝试。
分支判断的语句一般来说是不太适合进行SSE优化的,因为他会破坏代码的并行性,但是也不是所有的都是这样的,在合适的场景中运用SSE还是能对分支预测进行一定的优化的,我们这里以某一个算法的部分代码为例进行 ...
- Phalcon调试大杀器之phalcon-debugbar安装
Phalcon 是一款非常火的高性能C扩展php开发框架.特点是高性能低耦合,但遗憾的是长期缺少一款得力的调试辅助工具. 目前版本主要以Laravel debugbar的具有功能为蓝本开发,并针对ph ...
- 浏览器中显示PPT的展示效果
发现了一个PPT的WEb展示的方法,在浏览器中载入PDF文件之后,可以实现基于WEB的页面展示,支持全屏与自动播放. https://sharedoc.onk.ninja/ 这不失为一种可行的方式,且 ...
- 异步任务利器Celery(一)介绍
django项目开发中遇到过一些问题,发送请求后服务器要进行一系列耗时非常长的操作,用户要等待很久的时间.可不可以立刻对用户返回响应,然后在后台运行那些操作呢? crontab定时任务很难达到这样的要 ...