芝麻HTTP:TensorFlow LSTM MNIST分类
本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间。
初始化
首先我们可以先初始化一些变量,如学习率、节点单元数、RNN 层数等:
learning_rate = 1e- num_units = num_layer = input_size = time_step = total_steps = category_num = steps_per_validate = steps_per_test = batch_size = tf.placeholder(tf.int32, []) keep_prob = tf.placeholder(tf.float32, [])
然后还需要声明一下 MNIST 数据生成器:
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
接下来常规声明一下输入的数据,输入数据用 x 表示,标注数据用 y_label 表示:
x = tf.placeholder(tf.float32, [None, ]) y_label = tf.placeholder(tf.float32, [None, ])
这里输入的 x 维度是 [None, 784],代表 batch_size 不确定,输入维度 784,y_label 同理。
接下来我们需要对输入的 x 进行 reshape 操作,因为我们需要将一张图分为多个 time_step 来输入,这样才能构建一个 RNN 序列,所以这里直接将 time_step 设成 28,这样一来 input_size 就变为了 28,batch_size 不变,所以reshape 的结果是一个三维的矩阵:
x_shape = tf.reshape(x, [-, time_step, input_size])
RNN 层
接下来我们需要构建一个 RNN 模型了,这里我们使用的 RNN Cell 是 LSTMCell,而且要搭建一个三层的 RNN,所以这里还需要用到 MultiRNNCell,它的输入参数是 LSTMCell 的列表。
所以我们可以先声明一个方法用于创建 LSTMCell,方法如下:
def cell(num_units): cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units) return DropoutWrapper(cell, output_keep_prob=keep_prob)
这里还加入了 Dropout,来减少训练过程中的过拟合。
接下来我们再利用它来构建多层的 RNN:
cells = tf.nn.rnn_cell.MultiRNNCell([cell(num_units) for _ in range(num_layer)])
注意这里使用了 for 循环,每循环一次新生成一个 LSTMCell,而不是直接使用乘法来扩展列表,因为这样会导致 LSTMCell 是同一个对象,导致构建完 MultiRNNCell 之后出现维度不匹配的问题。
接下来我们需要声明一个初始状态:
h0 = cells.zero_state(batch_size, dtype=tf.float32)
然后接下来调用 dynamic_rnn() 方法即可完成模型的构建了:
output, hs = tf.nn.dynamic_rnn(cells, inputs=x_shape, initial_state=h0)
这里 inputs 的输入就是 x 做了 reshape 之后的结果,初始状态通过 initial_state 传入,其返回结果有两个,一个 output 是所有 time_step 的输出结果,赋值为 output,它是三维的,第一维长度等于 batch_size,第二维长度等于 time_step,第三维长度等于 num_units。另一个 hs 是隐含状态,是元组形式,长度即 RNN 的层数 3,每一个元素都包含了 c 和 h,即 LSTM 的两个隐含状态。
这样的话 output 的最终结果可以取最后一个 time_step 的结果,所以可以使用:
output = output[:, -, :]
或者直接取隐藏状态最后一层的 h 也是相同的:
h = hs[-].h
在此模型中,二者是等价的。但注意如果用于文本处理,可能由于文本长度不一,而 padding,导致二者不同。
输出层
接下来我们再做一次线性变换和 Softmax 输出结果即可:
# Output Layer w = tf.Variable(tf.truncated_normal([num_units, category_num], stddev=0.1), dtype=tf.float32) b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32) y = tf.matmul(output, w) + b # Loss cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y)
这里的 Loss 直接调用了 softmax_cross_entropy_with_logits 先计算了 Softmax,然后计算了交叉熵。
训练和评估
最后再定义训练和评估的流程即可,在训练过程中每隔一定的 step 就输出 Train Accuracy 和 Test Accuracy:
# Train train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy) # Prediction correction_prediction = tf.equal(tf.argmax(y, axis=), tf.argmax(y_label, axis=)) accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32)) # Train with tf.Session() as sess: sess.run(tf.global_variables_initializer()) ): batch_x, batch_y = mnist.train.next_batch() sess.run(train, feed_dict={x: batch_x, y_label: batch_y, keep_prob: ]}) # Train Accuracy : print('Train', step, sess.run(accuracy, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5, batch_size: batch_x.shape[]})) # Test Accuracy : test_x, test_y = mnist.test.images, mnist.test.labels print('Test', step, sess.run(accuracy, feed_dict={x: test_x, y_label: test_y, keep_prob: , batch_size: test_x.shape[]}))
运行
直接运行之后,只训练了几轮就可以达到 98% 的准确率:
Train 0.27 Test 0.2223 Train 0.87 Train 0.91 Train 0.94 Train 0.94 Train 0.99 Test 0.9595 Train 0.95 Train 0.97 Train 0.98
可以看出来 LSTM 在做 MNIST 字符分类的任务上还是比较有效的。
芝麻HTTP:TensorFlow LSTM MNIST分类的更多相关文章
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 2、TensorFlow训练MNIST
装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- TensorFlow LSTM 注意力机制图解
TensorFlow LSTM Attention 机制图解 深度学习的最新趋势是注意力机制.在接受采访时,现任OpenAI研究主管的Ilya Sutskever提到,注意力机制是最令人兴奋的进步之一 ...
- TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人
简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...
- Ubuntu16.04安装TensorFlow及Mnist训练
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com TensorFlow是Google开发的开源的深度学习框架,也是当前使用最广泛的深度学习框架. 一.安 ...
随机推荐
- BZOJ 2839: 集合计数 [容斥原理 组合]
2839: 集合计数 题意:n个元素的集合,选出若干子集使得交集大小为k,求方案数 先选出k个\(\binom{n}{k}\),剩下选出一些集合交集为空集 考虑容斥 \[ 交集为\emptyset = ...
- python---协程 学习笔记
协程 协程又称为微线程,协程是一种用户态的轻量级线程 协程拥有自己的寄存器和栈.协程调度切换的时候,将寄存器上下文和栈都保存到其他地方,在切换回来的时候,恢复到先前保存的寄存器上下文和栈,因此:协程能 ...
- Trie树/字典树题目(2017今日头条笔试题:异或)
/* 本程序说明: [编程题] 异或 时间限制:1秒 空间限制:32768K 给定整数m以及n个数字A1,A2,..An,将数列A中所有元素两两异或,共能得到n(n-1)/2个结果,请求出这些结果中大 ...
- u-boot核心初始化
异常向量表:异常:因为内部或者外部的一些事件,导致处理器停下正在处理的工作,转而去处理这些发生的事件.ARM Architecture Reference Manual p54页.7种异常的类型:Re ...
- Hive metastore源码阅读(三)
上次写了hive metastore的partition的生命周期,但是简略概括了下alter_partition的操作,这里补一下alter_partition,因为随着项目的深入,发现它涉及的地方 ...
- testng 异常 截图
testNG里有一个异常监听类,失败时会执行类里的相关方法 DriverBase 截图类 TestngListenerScreen 异常监听类 Test1 测试类1.DriverBase类 packa ...
- PHP die与exit的区别
最近听见有人说die和exit区别,bula~bula.决心一探究竟. 翻了翻PHP 5.6的源码(源码的位置为zend目录下zend_language_scanner.l大约是1014~1020行) ...
- 通过云主机(网关机)远程登录内网mysql
国内的一些云主机平台(UCloud,阿里云,腾讯云等)走的都是网关机+内网机(即局域网)模式,网关机代理外网访问,不能直接连接内网机器.本文介绍通过远程登录云主机,并设置本地代理的方式,通过sqlyo ...
- 基于Parallax设计HTML视差效果
年关将至,给大家拜年. 最近时间充裕了一点,给大家介绍一个比较有意思的控件:Parallax.它可以用来实现鼠标移动时,页面上的元素也做偏移的视差效果.在一些有表现层次,布局空旷的页面上,用来做Hea ...
- Dynamics 365 Online-多选域
参与过Dynamics CRM相关工作的朋友们都知道,Dynamics 365之前并没有多选域字段,想要实现多选域,需要自己添加WebResource定制,而这也带来了一系列需要考虑的情况,比如额外的 ...