tensorflow学习笔记七----------RNN
和神经网络不同的是,RNN中的数据批次之间是有相互联系的。输入的数据需要是要求序列化的。
1.将数据处理成序列化;
2.将一号数据传入到隐藏层进行处理,在传入到RNN中进行处理,RNN产生两个结果,一个结果产生分类结果,另外一个结果传入到二号数据的RNN中;
3.所有数据都处理完。
导入数据
- import tensorflow as tf
- import from tensorflow.examples.tutorials.mnist import input_data
- import numpy as np
- import matplotlib.pyplot as plt
- print ("Packages imported")
- mnist = input_data.read_data_sets("data/", one_hot=True)
- trainimgs, trainlabels, testimgs, testlabels \
- = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
- ntrain, ntest, dim, nclasses \
- = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
- print ("MNIST loaded")
将28*28像素的数据变成28条数据;隐藏层有128个神经元;定义好权重和偏置;
- diminput = 28
- dimhidden = 128
- dimoutput = nclasses
- nsteps = 28
- weights = {
- 'hidden': tf.Variable(tf.random_normal([diminput, dimhidden])),
- 'out': tf.Variable(tf.random_normal([dimhidden, dimoutput]))
- }
- biases = {
- 'hidden': tf.Variable(tf.random_normal([dimhidden])),
- 'out': tf.Variable(tf.random_normal([dimoutput]))
- }
定义RNN函数。将数据转化一下;计算隐藏层;将隐藏层切片;计算RNN产生的两个结果;预测值是最后一个RNN产生的LSTM_O
- def _RNN(_X, _W, _b, _nsteps, _name):
- # 1. Permute input from [batchsize, nsteps, diminput]
- # => [nsteps, batchsize, diminput]
- _X = tf.transpose(_X, [1, 0, 2])
- # 2. Reshape input to [nsteps*batchsize, diminput]
- _X = tf.reshape(_X, [-1, diminput])
- # 3. Input layer => Hidden layer
- _H = tf.matmul(_X, _W['hidden']) + _b['hidden']
- # 4. Splite data to 'nsteps' chunks. An i-th chunck indicates i-th batch data
- _Hsplit = tf.split(0, _nsteps, _H)
- # 5. Get LSTM's final output (_LSTM_O) and state (_LSTM_S)
- # Both _LSTM_O and _LSTM_S consist of 'batchsize' elements
- # Only _LSTM_O will be used to predict the output.
- with tf.variable_scope(_name) as scope:
- scope.reuse_variables()
- lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
- _LSTM_O, _LSTM_S = tf.nn.rnn(lstm_cell, _Hsplit,dtype=tf.float32)
- # 6. Output
- _O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']
- # Return!
- return {
- 'X': _X, 'H': _H, 'Hsplit': _Hsplit,
- 'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O
- }
- print ("Network ready")
定义好RNN后,定义损失函数等
- learning_rate = 0.001
- x = tf.placeholder("float", [None, nsteps, diminput])
- y = tf.placeholder("float", [None, dimoutput])
- myrnn = _RNN(x, weights, biases, nsteps, 'basic')
- pred = myrnn['O']
- cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
- optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Adam Optimizer
- accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))
- init = tf.global_variables_initializer()
- print ("Network Ready!")
进行训练
- training_epochs = 5
- batch_size = 16
- display_step = 1
- sess = tf.Session()
- sess.run(init)
- print ("Start optimization")
- for epoch in range(training_epochs):
- avg_cost = 0.
- total_batch = int(mnist.train.num_examples/batch_size)
- # Loop over all batches
- for i in range(total_batch):
- batch_xs, batch_ys = mnist.train.next_batch(batch_size)
- batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
- # Fit training using batch data
- feeds = {x: batch_xs, y: batch_ys}
- sess.run(optm, feed_dict=feeds)
- # Compute average loss
- avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
- # Display logs per epoch step
- if epoch % display_step == 0:
- print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
- feeds = {x: batch_xs, y: batch_ys}
- train_acc = sess.run(accr, feed_dict=feeds)
- print (" Training accuracy: %.3f" % (train_acc))
- testimgs = testimgs.reshape((ntest, nsteps, diminput))
- feeds = {x: testimgs, y: testlabels, istate: np.zeros((ntest, 2*dimhidden))}
- test_acc = sess.run(accr, feed_dict=feeds)
- print (" Test accuracy: %.3f" % (test_acc))
- print ("Optimization Finished.")
tensorflow学习笔记七----------RNN的更多相关文章
- tensorflow学习笔记七----------卷积神经网络
卷积神经网络比神经网络稍微复杂一些,因为其多了一个卷积层(convolutional layer)和池化层(pooling layer). 使用mnist数据集,n个数据,每个数据的像素为28*28* ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——自编码器及多层感知器
1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- Tensorflow学习笔记No.10
多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- (转)Qt Model/View 学习笔记 (七)——Delegate类
Qt Model/View 学习笔记 (七) Delegate 类 概念 与MVC模式不同,model/view结构没有用于与用户交互的完全独立的组件.一般来讲, view负责把数据展示 给用户,也 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(七) indigo PCL xtion pro live
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS forRobotics Pro ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
随机推荐
- PHP基础教程探讨一些php编程性能优化总结
兄弟连PHP培训 小编最近在做php程序的性能优化,一些经过测试后发现的东西就先记录下来,以备后用. 首先对于一些反应慢的操作或页面要跟踪处理一下,可以使用webGrind的方式看一下主要问题出在 ...
- luoguP1197 [JSOI2008]星球大战 x
P1197 [JSOI2008]星球大战 题目描述 很久以前,在一个遥远的星系,一个黑暗的帝国靠着它的超级武器统治者整个星系.某一天,凭着一个偶然的机遇,一支反抗军摧毁了帝国的超级武器,并攻下了星系中 ...
- 【BZOJ3876】 [Ahoi2014]支线剧情
Description [故事背景] 宅男JYY非常喜欢玩RPG游戏,比如仙剑,轩辕剑等等.不过JYY喜欢的并不是战斗场景,而是类似电视剧一般的充满恩怨情仇的剧情.这些游戏往往 都有很多的支线剧情,现 ...
- 在线前端 JS 或 HTML 或 CSS 编写 Demo 处 JSbin 与 jsFiddle 比较
JSBin 该编辑器的特点是编写可直接编写 HTML.CSS.JavaScript 并且可以在 output 中实时观察编写效果:可设置自动运行 JavasScript 代码,其中最大的好处是有一个 ...
- (58)PHP开发
LAMP 0.使用include和require命令来包含外部PHP文件. 使用include_once命令,但是include和include_once命令相比的不足就是这两个命令并不关心请求的文件 ...
- Logger工具类
org.slf4j.Logger的简单封装,传入所在类的class,和类名或全类名. public class LoggerFactory { public static Logger getLogg ...
- 3D Computer Grapihcs Using OpenGL - 13 优化矩阵
上节说过矩阵是可以结合的,而且相乘是按照和应用顺序相反的顺序进行的.我们之前初始化translationMatrix和rotationMatrix的时候,第一个参数都是使用的一个初始矩阵 glm::m ...
- Java 有几种修饰符?分别用来修饰什么
4种修饰符 访问权限 类 包 子类 其他包 public ∨ ∨ ∨ ∨ protect ∨ ∨ ∨ × default ∨ ∨ ...
- [C#菜鸟]C# Hook (一)
转过来的文章,出处已经不知道了,但只这篇步骤比较清晰,就贴出来了. 一.写在最前 本文的内容只想以最通俗的语言说明钩子的使用方法,具体到钩子的详细介绍可以参照下面的网址: http://www.mic ...
- 开发-日常工具:TFS(Team Foundation Server)
ylbtech-开发-日常工具:TFS(Team Foundation Server) TFS(Team Foundation Server)是一个高可扩展.高可用.高性能.面向互联网服务的分布式文件 ...