tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】
之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别。
而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时间轴上慢慢展开,有点类似我们大脑认识事物时会有相关的短期记忆。
这次我们使用RNN来识别手写数字。
首先导入数据并定义各种RNN的参数:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 导入数据
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# RNN各种参数定义
lr = 0.001 #学习速率
training_iters = 100000 #循环次数
batch_size = 128
n_inputs = 28 #手写字的大小是28*28,这里是手写字中的每行28列的数值
n_steps = 28 #这里是手写字中28行的数据,因为以一行一行像素值处理的话,正好是28行
n_hidden_units = 128 #假设隐藏单元有128个
n_classes = 10 #因为我们的手写字是0-9,因此最后要分成10个类
接着定义输入、输出以及各权重的形状:
# 定义输入和输出的placeholder
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])
# 对weights和biases初始值定义
weights = {
# shape(28, 128)
'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
# shape(128 , 10)
'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
biases = {
# shape(128, )
'in':tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
# shape(10, )
'out':tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}
定义 RNN 的主体结构
最主要的就是定义RNN的主体结构。
def RNN(X, weights, biases):
# X在输入时是一批128个,每批中有28行,28列,因此其shape为(128, 28, 28)。为了能够进行 weights 的矩阵乘法,我们需要把输入数据转换成二维的数据(128*28, 28)
X = tf.reshape(X, [-1, n_inputs])
# 对输入数据根据权重和偏置进行计算, 其shape为(128batch * 28steps, 128 hidden)
X_in = tf.matmul(X, weights['in']) + biases['in']
# 矩阵计算完成之后,又要转换成3维的数据结构了,(128batch, 28steps, 128 hidden)
X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])
# cell,使用LSTM,其中state_is_tuple用来指示相关的state是否是一个元组结构的,如果是元组结构的话,会在state中包含主线状态和分线状态
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
# 初始化全0state
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
# 下面进行运算,我们使用dynamic rnn来进行运算。每一步的运算输出都会存储在outputs中,states中存储了主线状态和分线状态,因为我们前面指定了state_is_tuple=True
# time_major用来指示关于时间序列的数据是否在输入数据中第一个维度中。在本例中,我们的时间序列数据位于第2维中,第一维的数据只是batch数据,因此要设置为False。
outputs, states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)
# 计算结果,其中states[1]为分线state,也就是最后一个输出值
results = tf.matmul(states[1], weights['out']) + biases['out']
return results
训练RNN
定义好了 RNN 主体结构后, 我们就可以来计算 cost 和 train_op:
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)
训练时, 不断输出 accuracy, 观看结果:
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
step = 0
while step*batch_size < training_iters:
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
sess.run([train_op], feed_dict={x:batch_xs, y:batch_ys})
if step % 20 == 0:
print(sess.run(accuracy, feed_dict={x:batch_xs, y:batch_ys}))
step += 1
最终 accuracy 的结果如下:
E:\Python\Python36\python.exe E:/learn/numpy/lesson3/main.py
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
2018-02-20 20:30:52.769108: I C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\36\tensorflow\core\platform\cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX
0.09375
0.710938
0.8125
0.789063
0.820313
0.882813
0.828125
0.867188
0.921875
0.90625
0.921875
0.890625
0.898438
0.945313
0.914063
0.945313
0.929688
0.96875
0.96875
0.929688
0.953125
0.945313
0.960938
0.992188
0.953125
0.9375
0.929688
0.96875
0.960938
0.945313
完整代码
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 导入数据
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# RNN各种参数定义
lr = 0.001 #学习速率
training_iters = 100000 #循环次数
batch_size = 128
n_inputs = 28 #手写字的大小是28*28,这里是手写字中的每行28列的数值
n_steps = 28 #这里是手写字中28行的数据,因为以一行一行像素值处理的话,正好是28行
n_hidden_units = 128 #假设隐藏单元有128个
n_classes = 10 #因为我们的手写字是0-9,因此最后要分成10个类
# 定义输入和输出的placeholder
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])
# 对weights和biases初始值定义
weights = {
# shape(28, 128)
'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
# shape(128 , 10)
'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
biases = {
# shape(128, )
'in':tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
# shape(10, )
'out':tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}
def RNN(X, weights, biases):
# X在输入时是一批128个,每批中有28行,28列,因此其shape为(128, 28, 28)。为了能够进行 weights 的矩阵乘法,我们需要把输入数据转换成二维的数据(128*28, 28)
X = tf.reshape(X, [-1, n_inputs])
# 对输入数据根据权重和偏置进行计算, 其shape为(128batch * 28steps, 128 hidden)
X_in = tf.matmul(X, weights['in']) + biases['in']
# 矩阵计算完成之后,又要转换成3维的数据结构了,(128batch, 28steps, 128 hidden)
X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])
# cell,使用LSTM,其中state_is_tuple用来指示相关的state是否是一个元组结构的,如果是元组结构的话,会在state中包含主线状态和分线状态
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
# 初始化全0state
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
# 下面进行运算,我们使用dynamic rnn来进行运算。每一步的运算输出都会存储在outputs中,states中存储了主线状态和分线状态,因为我们前面指定了state_is_tuple=True
# time_major用来指示关于时间序列的数据是否在输入数据中第一个维度中。在本例中,我们的时间序列数据位于第2维中,第一维的数据只是batch数据,因此要设置为False。
outputs, states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)
# 计算结果,其中states[1]为分线state,也就是最后一个输出值
results = tf.matmul(states[1], weights['out']) + biases['out']
return results
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
step = 0
while step*batch_size < training_iters:
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
sess.run([train_op], feed_dict={x:batch_xs, y:batch_ys})
if step % 20 == 0:
print(sess.run(accuracy, feed_dict={x:batch_xs, y:batch_ys}))
step += 1
tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】的更多相关文章
- tensorflow分类-【老鱼学tensorflow】
前面我们学习过回归问题,比如对于房价的预测,因为其预测值是个连续的值,因此属于回归问题. 但还有一类问题属于分类的问题,比如我们根据一张图片来辨别它是一只猫还是一只狗.某篇文章的内容是属于体育新闻还是 ...
- tensorflow用dropout解决over fitting-【老鱼学tensorflow】
在机器学习中可能会存在过拟合的问题,表现为在训练集上表现很好,但在测试集中表现不如训练集中的那么好. 图中黑色曲线是正常模型,绿色曲线就是overfitting模型.尽管绿色曲线很精确的区分了所有的训 ...
- tensorflow卷积神经网络-【老鱼学tensorflow】
前面我们曾有篇文章中提到过关于用tensorflow训练手写2828像素点的数字的识别,在那篇文章中我们把手写数字图像直接碾压成了一个784列的数据进行识别,但实际上,这个图像是2828长宽结构的,我 ...
- tensorflow例子-【老鱼学tensorflow】
本节主要用一个例子来讲述一下基本的tensorflow用法. 在这个例子中,我们首先伪造一些线性数据点,其实这些数据中本身就隐藏了一些规律,但我们假装不知道是什么规律,然后想通过神经网络来揭示这个规律 ...
- tensorflow Tensorboard可视化-【老鱼学tensorflow】
tensorflow自带了可视化的工具:Tensorboard.有了这个可视化工具,可以让我们在调整各项参数时有了可视化的依据. 本次我们先用Tensorboard来可视化Tensorflow的结构. ...
- tensorflow安装-【老鱼学tensorflow】
TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,Tensor ...
- tensorflow 传入值-【老鱼学tensorflow】
上个文章中讲述了tensorflow中如何定义变量以及如何读取变量的方式,本节主要讲述关于传入值. 变量主要用于在tensorflow系统中经常会被改变的值,而对于传入值,它只是当tensorflow ...
- tensorflow建造神经网络-【老鱼学tensorflow】
上次我们添加了一个add_layer函数,这次就要创建一个神经网络来预测/拟合相应的数据. 下面我们先来创建一下虚拟的数据,这个数据为二次曲线数据,但同时增加了一些噪点,其图像为: 相应的创建这些伪造 ...
- tensorflow Tensorboard2-【老鱼学tensorflow】
前面我们用Tensorboard显示了tensorflow的程序结构,本节主要用Tensorboard显示各个参数值的变化以及损失函数的值的变化. 这里的核心函数有: histogram 例如: tf ...
随机推荐
- windows下网络编程UDP
转载 C++ UDP客户端服务器Socket编程 UDPServer.cpp #include<winsock2.h>#include<stdio.h>#include< ...
- 使用VLC Activex插件做网页版视频播放器
网上找的一个小例子,包括时长播放时间等等都有. mrl可以设置本地文件,这样发布网站后只能播放本地有的文件, 如果视频文件全在服务器上,其他电脑想看的话,则可以IIS上发布个视频文件服务器,类似htt ...
- .Net Core在Centos7上初体验
本文主要内容是简单介绍如何在centos7上开发.Net Core项目,在此之前我们首先了解下.Net Core的基本特性. 1 .Net Core和.Net FrameWork的异同 1.1 .Ne ...
- python3 练手实例4 九九乘法口诀表
for i in range(1,10): for j in range(1,i+1): print('{}*{}={}\t'.format(i,j,i*j),end='') print()
- python复习2
在操作字符串时,我们经常遇到str和bytes的互相转换.为了避免乱码问题,应当始终坚持使用UTF-8编码对str和bytes进行转换.
- wifi基本原理
参考链接: https://www.cnblogs.com/zhoading/p/8891206.html https://openwrt.org/zh-cn/doc/uci/wireless htt ...
- [转] word2vec
from: https://www.cnblogs.com/peghoty/p/3857839.html 另附一个比较好的介绍:https://zhuanlan.zhihu.com/p/2630679 ...
- #2018-2019-2-20175204 张湲祯 实验一 《Java开发环境的熟悉》实验报告
2018-2019-2-20175204 张湲祯 实验一 <Java开发环境的熟悉>实验报告 一.实验内容及步骤 一.使用JDK编译.运行简单的Java程序 1.输入cd zyz命令进入z ...
- ZYNQ基础知识一
参考:UG1181 Zynq-7000 Programable Soc Architrcture Porting Quick Start Guide ...
- vue面试题总结
1.vue双向绑定的实现原理2.js的继承和原型链3.es6语法箭头函数和普通函数的区别 普通函数的this总是指向它的直接调用者. 在严格模式下,没找到直接调用者,则函数中的this是undefin ...