和神经网络不同的是,RNN中的数据批次之间是有相互联系的。输入的数据需要是要求序列化的。

1.将数据处理成序列化;

2.将一号数据传入到隐藏层进行处理,在传入到RNN中进行处理,RNN产生两个结果,一个结果产生分类结果,另外一个结果传入到二号数据的RNN中;

3.所有数据都处理完。

导入数据

  1. import tensorflow as tf
  2. import from tensorflow.examples.tutorials.mnist import input_data
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. print ("Packages imported")
  6.  
  7. mnist = input_data.read_data_sets("data/", one_hot=True)
  8. trainimgs, trainlabels, testimgs, testlabels \
  9. = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
  10. ntrain, ntest, dim, nclasses \
  11. = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
  12. print ("MNIST loaded")

将28*28像素的数据变成28条数据;隐藏层有128个神经元;定义好权重和偏置;

  1. diminput = 28
  2. dimhidden = 128
  3. dimoutput = nclasses
  4. nsteps = 28
  5. weights = {
  6. 'hidden': tf.Variable(tf.random_normal([diminput, dimhidden])),
  7. 'out': tf.Variable(tf.random_normal([dimhidden, dimoutput]))
  8. }
  9. biases = {
  10. 'hidden': tf.Variable(tf.random_normal([dimhidden])),
  11. 'out': tf.Variable(tf.random_normal([dimoutput]))
  12. }

定义RNN函数。将数据转化一下;计算隐藏层;将隐藏层切片;计算RNN产生的两个结果;预测值是最后一个RNN产生的LSTM_O

  1. def _RNN(_X, _W, _b, _nsteps, _name):
  2. # 1. Permute input from [batchsize, nsteps, diminput]
  3. # => [nsteps, batchsize, diminput]
  4. _X = tf.transpose(_X, [1, 0, 2])
  5. # 2. Reshape input to [nsteps*batchsize, diminput]
  6. _X = tf.reshape(_X, [-1, diminput])
  7. # 3. Input layer => Hidden layer
  8. _H = tf.matmul(_X, _W['hidden']) + _b['hidden']
  9. # 4. Splite data to 'nsteps' chunks. An i-th chunck indicates i-th batch data
  10. _Hsplit = tf.split(0, _nsteps, _H)
  11. # 5. Get LSTM's final output (_LSTM_O) and state (_LSTM_S)
  12. # Both _LSTM_O and _LSTM_S consist of 'batchsize' elements
  13. # Only _LSTM_O will be used to predict the output.
  14. with tf.variable_scope(_name) as scope:
  15.  
  16. scope.reuse_variables()
  17. lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
  18. _LSTM_O, _LSTM_S = tf.nn.rnn(lstm_cell, _Hsplit,dtype=tf.float32)
  19. # 6. Output
  20. _O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']
  21. # Return!
  22. return {
  23. 'X': _X, 'H': _H, 'Hsplit': _Hsplit,
  24. 'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O
  25. }
  26. print ("Network ready")

定义好RNN后,定义损失函数等

  1. learning_rate = 0.001
  2. x = tf.placeholder("float", [None, nsteps, diminput])
  3. y = tf.placeholder("float", [None, dimoutput])
  4. myrnn = _RNN(x, weights, biases, nsteps, 'basic')
  5. pred = myrnn['O']
  6. cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
  7. optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Adam Optimizer
  8. accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))
  9. init = tf.global_variables_initializer()
  10. print ("Network Ready!")

进行训练

  1. training_epochs = 5
  2. batch_size = 16
  3. display_step = 1
  4. sess = tf.Session()
  5. sess.run(init)
  6. print ("Start optimization")
  7. for epoch in range(training_epochs):
  8. avg_cost = 0.
  9. total_batch = int(mnist.train.num_examples/batch_size)
  10.  
  11. # Loop over all batches
  12. for i in range(total_batch):
  13. batch_xs, batch_ys = mnist.train.next_batch(batch_size)
  14. batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
  15. # Fit training using batch data
  16. feeds = {x: batch_xs, y: batch_ys}
  17. sess.run(optm, feed_dict=feeds)
  18. # Compute average loss
  19. avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
  20. # Display logs per epoch step
  21. if epoch % display_step == 0:
  22. print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
  23. feeds = {x: batch_xs, y: batch_ys}
  24. train_acc = sess.run(accr, feed_dict=feeds)
  25. print (" Training accuracy: %.3f" % (train_acc))
  26. testimgs = testimgs.reshape((ntest, nsteps, diminput))
  27. feeds = {x: testimgs, y: testlabels, istate: np.zeros((ntest, 2*dimhidden))}
  28. test_acc = sess.run(accr, feed_dict=feeds)
  29. print (" Test accuracy: %.3f" % (test_acc))
  30. print ("Optimization Finished.")

tensorflow学习笔记七----------RNN的更多相关文章

  1. tensorflow学习笔记七----------卷积神经网络

    卷积神经网络比神经网络稍微复杂一些,因为其多了一个卷积层(convolutional layer)和池化层(pooling layer). 使用mnist数据集,n个数据,每个数据的像素为28*28* ...

  2. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  3. tensorflow学习笔记——自编码器及多层感知器

    1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...

  4. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  5. Tensorflow学习笔记No.10

    多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...

  6. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

  7. (转)Qt Model/View 学习笔记 (七)——Delegate类

    Qt Model/View 学习笔记 (七) Delegate  类 概念 与MVC模式不同,model/view结构没有用于与用户交互的完全独立的组件.一般来讲, view负责把数据展示 给用户,也 ...

  8. Learning ROS for Robotics Programming Second Edition学习笔记(七) indigo PCL xtion pro live

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS forRobotics Pro ...

  9. Tensorflow学习笔记2019.01.22

    tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...

随机推荐

  1. PHP基础教程探讨一些php编程性能优化总结

      兄弟连PHP培训 小编最近在做php程序的性能优化,一些经过测试后发现的东西就先记录下来,以备后用. 首先对于一些反应慢的操作或页面要跟踪处理一下,可以使用webGrind的方式看一下主要问题出在 ...

  2. luoguP1197 [JSOI2008]星球大战 x

    P1197 [JSOI2008]星球大战 题目描述 很久以前,在一个遥远的星系,一个黑暗的帝国靠着它的超级武器统治者整个星系.某一天,凭着一个偶然的机遇,一支反抗军摧毁了帝国的超级武器,并攻下了星系中 ...

  3. 【BZOJ3876】 [Ahoi2014]支线剧情

    Description [故事背景] 宅男JYY非常喜欢玩RPG游戏,比如仙剑,轩辕剑等等.不过JYY喜欢的并不是战斗场景,而是类似电视剧一般的充满恩怨情仇的剧情.这些游戏往往 都有很多的支线剧情,现 ...

  4. 在线前端 JS 或 HTML 或 CSS 编写 Demo 处 JSbin 与 jsFiddle 比较

    JSBin 该编辑器的特点是编写可直接编写 HTML.CSS.JavaScript 并且可以在 output 中实时观察编写效果:可设置自动运行 JavasScript 代码,其中最大的好处是有一个 ...

  5. (58)PHP开发

    LAMP 0.使用include和require命令来包含外部PHP文件. 使用include_once命令,但是include和include_once命令相比的不足就是这两个命令并不关心请求的文件 ...

  6. Logger工具类

    org.slf4j.Logger的简单封装,传入所在类的class,和类名或全类名. public class LoggerFactory { public static Logger getLogg ...

  7. 3D Computer Grapihcs Using OpenGL - 13 优化矩阵

    上节说过矩阵是可以结合的,而且相乘是按照和应用顺序相反的顺序进行的.我们之前初始化translationMatrix和rotationMatrix的时候,第一个参数都是使用的一个初始矩阵 glm::m ...

  8. Java 有几种修饰符?分别用来修饰什么

    4种修饰符 访问权限   类   包  子类  其他包 public     ∨   ∨   ∨     ∨ protect    ∨   ∨   ∨     × default    ∨   ∨   ...

  9. [C#菜鸟]C# Hook (一)

    转过来的文章,出处已经不知道了,但只这篇步骤比较清晰,就贴出来了. 一.写在最前 本文的内容只想以最通俗的语言说明钩子的使用方法,具体到钩子的详细介绍可以参照下面的网址: http://www.mic ...

  10. 开发-日常工具:TFS(Team Foundation Server)

    ylbtech-开发-日常工具:TFS(Team Foundation Server) TFS(Team Foundation Server)是一个高可扩展.高可用.高性能.面向互联网服务的分布式文件 ...