tensorflow实现循环神经网络
包括卷积神经网络(CNN)在内的各种前馈神经网络模型, 其一次前馈过程的输出只与当前输入有关与历史输入无关.
递归神经网络(Recurrent Neural Network, RNN)充分挖掘了序列数据中的信息, 在时间序列和自然语言处理方面有着重要的应用.
递归神经网络可以展开为普通的前馈神经网络:
长短期记忆模型(Long-Short Term Memory)是RNN的常用实现. 与一般神经网络的神经元相比, LSTM神经元多了一个遗忘门.
LSTM神经元的输出除了与当前输入有关外, 还与自身记忆有关. RNN的训练算法也是基于传统BP算法增加了时间考量, 称为BPTT(Back-propagation Through Time)算法.
使用tensorflow内置rnn
tensorflow内置了递归神经网络的实现:
from tensorflow.python.ops import rnn, rnn_cell
tensorflow目前正在快速迭代中, 上述路径可能会发生变化.在0.6.0版本中上述路径是有效的.
官方教程中已经加入了循环神经网络的部分, API可能不会发生太大变化.
Tensorflow有多种rnn神经元可供选择:
rnn_cell.BasicLSTMCell
rnn_cell.LSTMCell
rnn_cell.GRUCell
这里我们选用最简单的BasicLSTMCell, 需要设置神经元个数和forget_bias
参数:
self.lstm_cell = rnn_cell.BasicLSTMCell(hidden_n, forget_bias=1.0)
可以直接调用cell对象获得输出和状态:
output, state = cell(inputs, state)
使用dropout避免过拟合问题:
from tensorflow.python.ops.rnn_cell import Dropoutwrapper
cells = DropoutWrapper(lstm_cell, input_keep_prob=0.5, output_keep_prob=0.5)
使用MultiRNNCell来创建多层神经网络:
from tensorflow.python.ops.rnn_cell import MultiRNNCell
cells = MultiRNNCell([lstm_cell_1, lstm_cell_2])
不过rnn.rnn
可以替我们完成神经网络的构建工作:
outputs, states = rnn.rnn(self.lstm_cell, self.input_layer, dtype=tf.float32)
再加一个输出层进行输出:
self.prediction = tf.matmul(outputs[-1], self.weights) + self.biases
定义损失函数:
self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.prediction, self.label_layer))
使用Adam优化器进行训练:
self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)
因为神经网络需要处理序列数据, 所以输入层略复杂:
self.input_layer = [tf.placeholder("float", [step_n, input_n]) for i in range(batch_size)]
tensorflow要求RNNCell的输入为一个列表, 列表中的每一项作为一个批次进行训练.
列表中的每一个元素代表一个序列, 每一行为序列中的一项. 这样每一项为一个形状为(序列长, 输入维数)的矩阵.
标签还是和原来一样为形如(序列长, 输出维度)的矩阵:
self.label_layer = tf.placeholder("float", [step_n, output_n])
执行训练:
self.session.run(initer)
for i in range(limit):
self.session.run(self.trainer, feed_dict={self.input_layer[0]: train_x[0], self.label_layer: train_y})
因为input_layer
为列表, 而列表不能作为字典的键.所以我们只能采用{self.input_layer[0]: train_x[0]}
这样的方式输入数据.
可以看到lable_layer
也是二维的, 并没有输入多个批次的数据. 考虑到这两点, 目前这个实现并不具备多批次处理的能力.
序列的长度通常是不同的, 而目前的实现采用的是定长输入. 这是需要解决的另一个难题.
完整源代码可以在demo.py中查看.
tensorflow实现循环神经网络的更多相关文章
- 基于TensorFlow的循环神经网络(RNN)
RNN适用场景 循环神经网络(Recurrent Neural Network)适合处理和预测时序数据 RNN的特点 RNN的隐藏层之间的节点是有连接的,他的输入是输入层的输出向量.extend(上一 ...
- tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】
之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别. 而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时 ...
- Tensorflow中循环神经网络及其Wrappers
tf.nn.rnn_cell.LSTMCell 又名:tf.nn.rnn_cell.BasicLSTMCell.tf.contrib.rnn.LSTMCell 参见: tf.nn.rnn_cell.L ...
- TensorFlow系列专题(七):一文综述RNN循环神经网络
欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 前言 RNN知识结构 简单循环神经网络 RNN的基本结构 RNN的运算过程 ...
- 4.5 RNN循环神经网络(recurrent neural network)
自己开发了一个股票智能分析软件,功能很强大,需要的点击下面的链接获取: https://www.cnblogs.com/bclshuai/p/11380657.html 1.1 RNN循环神经网络 ...
- 学习笔记TF057:TensorFlow MNIST,卷积神经网络、循环神经网络、无监督学习
MNIST 卷积神经网络.https://github.com/nlintz/TensorFlow-Tutorials/blob/master/05_convolutional_net.py .Ten ...
- 学习笔记TF053:循环神经网络,TensorFlow Model Zoo,强化学习,深度森林,深度学习艺术
循环神经网络.https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/re ...
- TensorFlow——循环神经网络基本结构
1.导入依赖包,初始化一些常量 import collections import numpy as np import tensorflow as tf TRAIN_DATA = "./d ...
- TensorFlow学习笔记(六)循环神经网络
一.循环神经网络简介 循环神经网络的主要用途是处理和预测序列数据.循环神经网络刻画了一个序列当前的输出与之前信息的关系.从网络结构上,循环神经网络会记忆之前的信息,并利用之前的信息影响后面节点的输出. ...
随机推荐
- 《C#从现象到本质》读书笔记(五)第5章字符串第6章垃圾回收第7章异常与异常处理
<C#从现象到本质>读书笔记(五)第5章字符串 字符串是引用类型,但如果在某方法中,将字符串传入另一方法,在另一方法内部修改,执行完之后,字符串的只并不会改变,而引用类型无论是按值传递还是 ...
- js--随机产生100个从0 ~ 1000之间不重复的整数(me)
<style> div{text-indent:40px;} </style> <script> window.onload=function(){ v ...
- 复制命令(XCOPY)
XCOPY 命令: // 描述: 将文件或目录(包括子目录)从一个位置复制到另一个位置. // 语法: Xcopy <Source> [<Destination>] [/w] ...
- Arrays工具类和Collections工具类
集合知识点总结 Arrays工具类 .binarySearch() .sort() .fill() //填充 int[] array = new int[10]; Arrays.fill(array, ...
- 安装及使用virtualenv
安装tensorflow之virtualenv 在安装之前首先保证ubuntu.python,以及一些相应的包安装成功. 1.安装virtualenv#(1) pip $ sudo apt-get i ...
- ubuntu16.04 下安装 visual studio code 以及利用 g++ 运行 c++程序
参考链接:1. http://www.linuxidc.com/Linux/2016-07/132798.htm(安装vs code) 2.https://blog.csdn.net/qq_28598 ...
- 【转】nc 使用说明
netcat是网络工具中的瑞士军刀,它能通过TCP和UDP在网络中读写数据.通过与其他工具结合和重定向,你可以在脚本中以多种方式使用它.使用netcat命令所能完成的事情令人惊讶. netcat所做的 ...
- qhfl-7 结算中心
结算中心,即从购物车前往支付前的确认页面,这里要开始选择优惠券了 """ 前端传过来数据 course_list 课程列表 redis 中将要存放的结算数据 { sett ...
- shell中与C语言中的区别
shell中为啥与C语言有区别呢?弄成一样的不是很好么,其实不然,shell提供很多操作,这些操作不单单是执行程序或者命令,在很多时候是执行脚本的,简单的shell就是脚本编程,它的主要目的是处理文件 ...
- hdu 4027 Can you answer these queries?[线段树]
题目 题意: 输入一个 :n .(1<=n<<100000) 输入n个数 (num<2^63) 输入一个m :代表m个操作 (1<=m<<100 ...