和神经网络不同的是,RNN中的数据批次之间是有相互联系的。输入的数据需要是要求序列化的。

1.将数据处理成序列化;

2.将一号数据传入到隐藏层进行处理,在传入到RNN中进行处理,RNN产生两个结果,一个结果产生分类结果,另外一个结果传入到二号数据的RNN中;

3.所有数据都处理完。

导入数据

import tensorflow as tf
import from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
print ("Packages imported") mnist = input_data.read_data_sets("data/", one_hot=True)
trainimgs, trainlabels, testimgs, testlabels \
= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
ntrain, ntest, dim, nclasses \
= trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")

将28*28像素的数据变成28条数据;隐藏层有128个神经元;定义好权重和偏置;

diminput  = 28
dimhidden = 128
dimoutput = nclasses
nsteps = 28
weights = {
'hidden': tf.Variable(tf.random_normal([diminput, dimhidden])),
'out': tf.Variable(tf.random_normal([dimhidden, dimoutput]))
}
biases = {
'hidden': tf.Variable(tf.random_normal([dimhidden])),
'out': tf.Variable(tf.random_normal([dimoutput]))
}

定义RNN函数。将数据转化一下;计算隐藏层;将隐藏层切片;计算RNN产生的两个结果;预测值是最后一个RNN产生的LSTM_O

def _RNN(_X, _W, _b, _nsteps, _name):
# 1. Permute input from [batchsize, nsteps, diminput]
# => [nsteps, batchsize, diminput]
_X = tf.transpose(_X, [1, 0, 2])
# 2. Reshape input to [nsteps*batchsize, diminput]
_X = tf.reshape(_X, [-1, diminput])
# 3. Input layer => Hidden layer
_H = tf.matmul(_X, _W['hidden']) + _b['hidden']
# 4. Splite data to 'nsteps' chunks. An i-th chunck indicates i-th batch data
_Hsplit = tf.split(0, _nsteps, _H)
# 5. Get LSTM's final output (_LSTM_O) and state (_LSTM_S)
# Both _LSTM_O and _LSTM_S consist of 'batchsize' elements
# Only _LSTM_O will be used to predict the output.
with tf.variable_scope(_name) as scope: scope.reuse_variables()
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
_LSTM_O, _LSTM_S = tf.nn.rnn(lstm_cell, _Hsplit,dtype=tf.float32)
# 6. Output
_O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']
# Return!
return {
'X': _X, 'H': _H, 'Hsplit': _Hsplit,
'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O
}
print ("Network ready")

定义好RNN后,定义损失函数等

learning_rate = 0.001
x = tf.placeholder("float", [None, nsteps, diminput])
y = tf.placeholder("float", [None, dimoutput])
myrnn = _RNN(x, weights, biases, nsteps, 'basic')
pred = myrnn['O']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Adam Optimizer
accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))
init = tf.global_variables_initializer()
print ("Network Ready!")

进行训练

training_epochs = 5
batch_size = 16
display_step = 1
sess = tf.Session()
sess.run(init)
print ("Start optimization")
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size) # Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
# Fit training using batch data
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
# Compute average loss
avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
# Display logs per epoch step
if epoch % display_step == 0:
print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print (" Training accuracy: %.3f" % (train_acc))
testimgs = testimgs.reshape((ntest, nsteps, diminput))
feeds = {x: testimgs, y: testlabels, istate: np.zeros((ntest, 2*dimhidden))}
test_acc = sess.run(accr, feed_dict=feeds)
print (" Test accuracy: %.3f" % (test_acc))
print ("Optimization Finished.")

tensorflow学习笔记七----------RNN的更多相关文章

  1. tensorflow学习笔记七----------卷积神经网络

    卷积神经网络比神经网络稍微复杂一些,因为其多了一个卷积层(convolutional layer)和池化层(pooling layer). 使用mnist数据集,n个数据,每个数据的像素为28*28* ...

  2. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  3. tensorflow学习笔记——自编码器及多层感知器

    1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...

  4. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  5. Tensorflow学习笔记No.10

    多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...

  6. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

  7. (转)Qt Model/View 学习笔记 (七)——Delegate类

    Qt Model/View 学习笔记 (七) Delegate  类 概念 与MVC模式不同,model/view结构没有用于与用户交互的完全独立的组件.一般来讲, view负责把数据展示 给用户,也 ...

  8. Learning ROS for Robotics Programming Second Edition学习笔记(七) indigo PCL xtion pro live

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS forRobotics Pro ...

  9. Tensorflow学习笔记2019.01.22

    tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...

随机推荐

  1. layui 源码解读(部分)

    <!DOCTYPE html> <head> </head> <body> <input type="button" id=& ...

  2. Codeforces 912D Fishs ( 贪心 && 概率期望 && 优先队列 )

    题意 : 给出一个 N * M 的网格,然后给你 K 条鱼给你放置,现有规格为 r * r 的渔网,问你如果渔网随意放置去捕捞小鱼的情况下,捕到的最大期望值是多少? 分析 :  有一个很直观的想法就是 ...

  3. 大数据笔记(三十)——一篇文章读懂SparkSQL

    Spark SQL:类似Hive ======================================================= 一.Spark SQL基础 1.什么是Spark SQ ...

  4. libusb开发者指南

      本文档描述libusb的API,以及如何开发USB应用.1 介绍 1.1 概览本文档描述libusb-0.1的API和USB相关内容.1.2 当前OS支持Linux 2.2或以上FreeBSD/N ...

  5. Ansible安装及常用模块

    配置文件:/etc/ansible/ansible.cfg 主机列表:/etc/ansible/hosts  安装anslibe  wget -O /etc/yum.repos.d/epel.repo ...

  6. c++11多线程---线程入口函数

    1.普通函数(线程入口) #include <thread> #include <iostream> void hello(const char *name) { std::c ...

  7. leetcode 125 验证回文字符串 Valid Palindrome

    验证回文字符串 C++ 思路就是先重新定义一个string ,先遍历第一遍,字符串统一小写,去除空格:然后遍历第二遍,首尾一一对应比较:时间复杂度O(n+n/2),空间O(n); class Solu ...

  8. Html/CSS 示例演练 图书馆后台界面

    示例演练(html css javascript) --制作图书馆后台界面 1. 成品图

  9. await Vue.nextTick() 的含义分析

    概述 今天看别人的单元测试代码的时候碰到了一段代码 await Vue.nextTick(),初看起来不是很懂,后来通过查资料弄懂了,记录下来,供以后开发时参考,相信对其他人也有用. await Vu ...

  10. ActionList及Action使用

    ActionList及Action使用 https://blog.csdn.net/adamrao/article/details/7450889 2012年04月11日 19:09:27 阅读数:1 ...