Ternsorflow 学习:005-MNIST入门 实现模型
前言
在上一讲中,我们通过分析选用了softmax模型,并用tf创建之。本讲的内容就是为了训练这个模型以便于测试。
训练模型
为了训练我们的模型,我们首先需要定义一个指标来评估这个模型是好的。其实,在机器学习,我们通常定义指标来表示一个模型是坏的,这个指标称为成本(cost)或损失(loss),然后尽量最小化这个指标。
用于衡量模型好坏的工具,可以称之为成本函数。有关概念
一个非常常见的,非常漂亮的成本函数是“交叉熵”(cross-entropy)。交叉熵产生于信息论里面的信息压缩编码技术,但是它后来演变成为从博弈论到机器学习等其他领域里的重要技术手段。它的定义如下:
\]
y 是我们预测的概率分布, y' 是实际的分布(我们输入的one-hot vector)。比较粗糙的理解是,交叉熵是用来衡量我们的预测用于描述真相的低效性。更详细的关于交叉熵的解释超出本教程的范畴,但是你很有必要好好理解它。
为了计算交叉熵,我们首先需要添加一个新的占位符用于输入正确值:
y_ = tf.placeholder("float", [None,10])
# 注意: 这里的y_ 相当于公式中的y'
然后我们可以用下面公式计算交叉熵:
\]
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
首先,用 tf.log
计算 y
的每个元素的对数。
接下来,我们把 y_
的每一个元素和 tf.log(y_)
的对应元素相乘。
最后,用 tf.reduce_sum
计算张量的所有元素的总和。
注意,这里的交叉熵不仅仅用来衡量单一的一对预测和真实值,而是所有100幅图片的交叉熵的总和。
对于100个数据点的预测表现比单一数据点的表现能更好地描述我们的模型的性能。
现在我们知道我们需要我们的模型做什么啦,用TensorFlow来训练它是非常容易的。因为TensorFlow拥有一张描述你各个计算单元的图,它可以自动地使用反向传播算法(backpropagation algorithm)来有效地确定你的变量是如何影响你想要最小化的那个成本值的。然后,TensorFlow会用你选择的优化算法来不断地修改变量以降低成本。
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
在这里,我们要求TensorFlow用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵。梯度下降算法(gradient descent algorithm)是一个简单的学习过程,TensorFlow只需将每个变量一点点地往使成本不断降低的方向移动。
当然TensorFlow也提供了其他许多优化算法:只要简单地调整一行代码就可以使用其他的算法。
TensorFlow在这里实际上所做的是,它会在后台给描述你的计算的那张图里面增加一系列新的计算操作单元用于实现反向传播算法和梯度下降算法。然后,它返回给你的只是一个单一的操作,当运行这个操作时,它用梯度下降算法训练你的模型,微调你的变量,不断减少成本。
至此,我们的模型已经设置好了。 接下来就是运行计算。
在运行计算之前,我们需要添加一个操作来初始化我们创建的变量:
init = tf.initialize_all_variables()
在一个Session
里面启动我们的模型,并且初始化变量:
sess = tf.Session()
sess.run(init)
然后开始训练模型,这里我们让模型循环训练1000次。
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
该循环的每个步骤中,我们都会随机抓取训练数据中的100个批处理数据点,然后我们用这些数据点作为参数替换之前的占位符来运行train_step
。
使用一小部分的随机数据来进行训练被称为随机训练(stochastic training)。
在这里更确切的说是随机梯度下降训练。在理想情况下,我们希望用我们所有的数据来进行每一步的训练,因为这能给我们更好的训练结果,但显然这需要很大的计算开销。所以,每一次训练我们可以使用不同的数据子集,这样做既可以减少计算开销,又可以最大化地学习到数据集的总体特性。
评估我们的模型
那么我们的模型性能如何呢?
首先,找出那些预测正确的标签:tf.argmax
是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。
由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签。
比如
tf.argmax(y,1)
返回的是模型对于任一输入x预测到的标签值,而tf.argmax(y_,1)
代表正确的标签,我们可以用tf.equal
来检测我们的预测是否真实标签匹配(索引位置一样表示匹配),函数会返回一组布尔数组。
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
这行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。
例如,
[True, False, True, True]
会变成[1,0,1,1]
,取平均值后得到0.75
.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
最后,我们计算所学习到的模型在测试数据集上面的正确率。
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
这个最终结果值应该大约是91%。
WARNING:tensorflow:From /home/schips/tf-demo/tensorflow-dev/lib/python3.5/site-packages/tensorflow/python/util/tf_should_use.py:193: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Usetf.global_variables_initializer
instead.
WARNING:tensorflow:From tf_201_mnist_1.py:40: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.
2020-01-07 22:21:27.767925: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-01-07 22:21:27.776861: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2903995000 Hz
2020-01-07 22:21:27.777645: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x4fd2c50 executing computations on platform Host. Devices:
2020-01-07 22:21:27.777685: I tensorflow/compiler/xla/service/service.cc:175] StreamExecutor device (0): ,
2020-01-07 22:21:27.787265: W tensorflow/compiler/jit/mark_for_compilation_pass.cc:1412] (One-time warning): Not using XLA:CPU for cluster because envvar TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want XLA:CPU, either set that envvar, or use experimental_jit_scope to enable XLA:CPU. To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a proper command-line flag, not via TF_XLA_FLAGS) or set the envvar XLA_FLAGS=--xla_hlo_profile.
0.918
这个结果好吗?嗯,并不太好。事实上,这个结果是很差的。这是因为我们仅仅使用了一个非常简单的模型。不过,做一些小小的改进,我们就可以得到97%的正确率。最好的模型甚至可以获得超过99.7%的准确率!(想了解更多信息,可以看看这个关于各种模型的性能对比列表。)
比结果更重要的是,我们从这个模型中学习到的设计思想。不过,如果你仍然对这里的结果有点失望,可以查看下一个教程,在那里你可以学习如何用FensorFlow构建更加复杂的模型以获得更好的性能!
附录: 最近几章用的程序
import tensorflow as tf
# 上上一讲的内容
from mnist import input_data
mnist = input_data.read_data_sets("mnist-datasets/", one_hot=True)
# 通过操作符号变量来描述这些可交互的操作单元,可以用下面的方式创建一个
x = tf.placeholder("float", [None, 784])
# 我们的模型也需要权重值和偏置量,当然我们可以把它们当做是另外的输入(使用占位符),但TensorFlow有一个更好的方法来表示它们:Variable 。 一个Variable代表一个可修改的张量,存在在TensorFlow的用于描述交互性操作的图中。它们可以用于计算输入值,也可以在计算中被修改。对于各种机器学习应用,一般都会有模型参数,可以用Variable表示。
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
# 我们赋予tf.Variable不同的初值来创建不同的Variable:在这里,我们都用全为零的张量来初始化W和b。因为我们要学习W和b的值,它们的初值可以随意设置。
# 注意,W的维度是[784,10],因为我们想要用784维的图片向量乘以它以得到一个10维的证据值向量,每一位对应不同数字类。b的形状是[10],所以我们可以直接把它加到输出上面。
# 模型的实现
y = tf.nn.softmax(tf.matmul(x,W) + b)
# 为了计算交叉熵,我们首先需要添加一个新的占位符用于输入正确值
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# 先用 tf.log 计算 y 的每个元素的对数。
# 接下来,我们把 y_ 的每一个元素和 tf.log(y_) 的对应元素相乘。
# 最后,用 tf.reduce_sum 计算张量的所有元素的总和。
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# 在这里,我们要求TensorFlow用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵。
# 梯度下降算法(gradient descent algorithm)是一个简单的学习过程,TensorFlow只需将每个变量一点点地往使成本不断降低的方向移动。
## 至此,我们的模型已经设置好了。 接下来就是运行计算环境
# 运行计算之前,我们需要添加一个操作来初始化我们的变量
init = tf.initialize_all_variables()
# 在一个Session里面启动我们的模型,并且初始化变量:
sess = tf.Session()
sess.run(init)
# 开始训练模型,这里我们让模型循环训练1000次!
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# 该循环的每个步骤中,我们都会随机抓取训练数据中的100个批处理数据点,然后我们用这些数据点作为参数替换之前的占位符来运行train_step。
# 评估训练模型
# 1. 找出那些预测正确的标签 tf.i#complete_opened(1)
#tf=argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。
#由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签。
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
# 将bool转为float
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# 运行并输出结果
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
Ternsorflow 学习:005-MNIST入门 实现模型的更多相关文章
- Ternsorflow 学习:004-MNIST入门 构建模型
Softmax回归介绍 我们知道MNIST的每一张图片都表示一个数字,从0到9.我们希望得到给定图片代表每个数字的概率.比如说,我们的模型可能推测一张包含9的图片代表数字9的概率是80%但是判断它是8 ...
- Ternsorflow 学习:003-MNIST入门有关概念
前言 当我们开始学习编程的时候,第一件事往往是学习打印"HelloWorld".就好比编 程入门有 HelloWorld,机器学习入门有 MNIST. MNIST 是一个入门级的计 ...
- Ternsorflow 学习:006-MNIST进阶 深入MNIST
前言 这篇文章适合实践过MNIST入门的人学习观看.没有看过MNIST基础的人请移步这里 深入MNIST TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训 ...
- 腾讯QQ会员技术团队:人人都可以做深度学习应用:入门篇(下)
四.经典入门demo:识别手写数字(MNIST) 常规的编程入门有"Hello world"程序,而深度学习的入门程序则是MNIST,一个识别28*28像素的图片中的手写数字的程序 ...
- MongoDB学习笔记:快速入门
MongoDB学习笔记:快速入门 一.MongoDB 简介 MongoDB 是由C++语言编写的,是一个基于分布式文件存储的开源数据库系统.在高负载的情况下,添加更多的节点,可以保证服务器性能.M ...
- 博主从零开始学习HTML(入门基础)
目录 从零开始学习HTML(入门基础) 互联网三大基石 HTML的Head标签中的常用元素 字体格式化标签 字符实体,以下写最常用的几个 html常用标签及解析 a标签 img标签 媒体标签audio ...
- 从零开始学习jQuery (一) 入门篇
本系列文章导航 从零开始学习jQuery (一) 入门篇 一.摘要 本系列文章将带您进入jQuery的精彩世界, 其中有很多作者具体的使用经验和解决方案, 即使你会使用jQuery也能在阅读中发现些 ...
- python学习笔记--Django入门四 管理站点--二
接上一节 python学习笔记--Django入门四 管理站点 设置字段可选 编辑Book模块在email字段上加上blank=True,指定email字段为可选,代码如下: class Autho ...
- Thinkphp入门 五 —模型 (49)
原文:Thinkphp入门 五 -模型 (49) [数据库操作model模型] model 模型 数据库操作 tp框架主要设计模式:MVC C:controller 控制器 shop/Li ...
随机推荐
- [原]greenplum安装详细过程
今天又帮其他项目装了一遍GP,加上之前的两次,这是第三次了,虽然每次都有记录,但这次安装还是发现漏写了一些步骤,在此详细记录一下,需要的童鞋可以借鉴. 1.准备 这里准备了4台服务器,1台做maste ...
- request库解析中文
官网地址: http://cn.python-requests.org/zh_CN/latest/ 高级用法 本篇文档涵盖了 Requests 的一些高级特性. 会话对象 会话对象让你能够跨请求保持某 ...
- php面试题之PHP核心技术
一.PHP核心技术 更多PHP相关知识请关注我的专栏PHPzhuanlan.zhihu.com 1.写出一个能创建多级目录的PHP函数(新浪网技术部) <?php /** * 创建多级目录 * ...
- c#DDOS代码
//在工程属性中设置"允许不安全代码"为true ?using System; using System.Net; using System.Net.Sockets; using ...
- 区间树Splay——[NOI2005]维护数列
无指针Splay超详细讲解 区间树这玩意真TM玄学. 学这东西你必须要拥有的 1.通过[模板]文艺平衡树(Splay),[模板]普通平衡树,GSS3 - Can you answer these qu ...
- selenium+chrome options
selenium+chrome options 环境:selenium chrome 1. selenium + chrome参数配置 1.1. 启动 from selenium im ...
- angular 自定义服务封装自定义http请求
在angular中将http请求,放置在一起封装成服务,可减少代码重复,方便使用 var ngpohttprest = angular.module('ngpohttprest', []); ngpo ...
- 「NOI2005」维护数列
「NOI2005」维护数列 传送门 维护过程有点像线段树. 但我们知道线段树的节点并不是实际节点,而平衡树的节点是实际节点. 所以在向上合并信息时要加入根节点信息. 然后节点再删除后编号要回退(栈), ...
- MVC webuploader 图片
AARTO:SaveNoticeAndDocument ~/Scripts/Upload/webuploader-0.1.4/dist/webuploader.js
- Xcode Edit Schemes
关于本文:有关“Xcode Edit Schemes”的设置,还是有很大的学问的.由于时间关系,我一点一点的补充. 1.在开发的时候,至少将Run的Build Configuration设置为Debu ...