深度学习之循环神经网络RNN概述,双向LSTM实现字符识别
深度学习之循环神经网络RNN概述,双向LSTM实现字符识别
2. RNN概述
Recurrent Neural Network - 循环神经网络,最早出现在20世纪80年代,主要是用于时序数据的预测和分类。它的基本思想是:前向将上一个时刻的输出和本时刻的输入同时作为网络输入,得到本时刻的输出,然后不断地重复这个过程。后向通过BPTT(Back Propagation Through Time)算法来训练得到网络的权重。RNN比CNN更加彻底的是,CNN通过卷积运算共享权重从而减少计算量,而RNN从头到尾所有的权重都是公用的,不同的只是输入和上一时刻的输出。RNN的缺点在于长时依赖容易被遗忘,从而使得长时依赖序列的预测效果较差。
LSTM(Long Short Memory)是RNN最著名的一次改进,它借鉴了人类神经记忆的长短时特性,通过门电路(遗忘门,更新门)的方式,保留了长时依赖中较为重要的信息,从而使得RNN的性能大幅度的提高。
为了提高LSTM的计算效率,学术界又提供了很多变体形式,最著名的要数GRU(Gated Recurrent Unit),在减少一个门电路的前提下,仍然保持了和LSTM近似的性能,成为了语音和nlp领域的宠儿。
这篇文章翻译自海外著名的一篇RNN的科普博客,具有很好的借鉴意义。
3. 双向LSTM实现字符识别
下面的代码实现了一个双向的LSTM网络来进行mnist数据集的字符识别问题,双向的LSTM优于单向LSTM的是它可以同时利用过去时刻和未来时刻两个方向上的信息,从而使得最终的预测更加的准确。
Tensorflow提供了对LSTM Cell的封装,这里我们使用BasicLSTMCell,定义前向和后向的LSTM Cell:
lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
然后通过static_bidrectional_rnn函数将这两个cell以及时序输入x进行整合:
outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
lstm_fw_cell,
lstm_bw_cell,
x,
dtype=tf.float32
)
完整的代码如下:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
learning_rate = 0.01
max_samples = 400000
batch_size = 128
display_step = 10
n_input = 28
n_steps = 28
n_hidden = 256
n_classes = 10
x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
weights = tf.Variable(tf.random_normal([2 * n_hidden, n_classes]))
biases = tf.Variable(tf.random_normal([n_classes]))
def BiRNN(x, weights, biases):
x = tf.transpose(x, [1, 0, 2])
x = tf.reshape(x, [-1, n_input])
x = tf.split(x, n_steps)
lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
lstm_fw_cell,
lstm_bw_cell,
x,
dtype=tf.float32
)
return tf.matmul(outputs[-1], weights) + biases
pred = BiRNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,
labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
step = 1
while step * batch_size < max_samples:
batch_x, batch_y = mnist.train.next_batch(batch_size)
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter" + str(step * batch_size) + ", Minibatch Loss=" + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc))
step += 1
print ("Optimization Finishes!")
test_len = 50000
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing accuracy:",
sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
这里选择了400000个sample进行训练,图像按行读入像素序列(总共n_step=28行),每128个样本看成一个batch做一次BPTT,每10个batch打印一次training loss。
Iter396800, Minibatch Loss=0.038339, Training Accuracy= 0.98438
Iter398080, Minibatch Loss=0.007602, Training Accuracy= 1.00000
Iter399360, Minibatch Loss=0.024104, Training Accuracy= 0.99219
Optimization Finishes!
取50000个样本作为测试集,准确率为:
('Testing accuracy:', 0.98680007)
可以发现,双向LSTM做图像分类虽然也有不错的性能,但是还是比CNN略微逊色。主要原因应该还是因为图像数据属于层次性比较高的数据,CNN能够逐层抽取图像的层次特征,从而达到比较高的精度。
但是可以想象,对于时序性比较强的无空间结构数据,RNN会有更加出色的表现。
===================
CNN是做图像识别的,对彩票一点用都没有。彩票预测分为两种,一直是M选N型,比如双色球,大乐透,另外一种是M选1型,比如福彩3d在各位上选一个。
M选1型 的可以用非线性回归算法进行预测。KNN这个是典型的非线下回归算法,测试效果并不理想。贝叶斯,随机森林,SVM, GBDT可以测试看看。
深度学习之循环神经网络RNN概述,双向LSTM实现字符识别的更多相关文章
- 深度学习之循环神经网络(RNN)
循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络,适合用于处理视频.语音.文本等与时序相关的问题.在循环神经网络中,神经元不但可以接收其他神经元 ...
- TensorFlow深度学习实战---循环神经网络
循环神经网络(recurrent neural network,RNN)-------------------------重要结构(长短时记忆网络( long short-term memory,LS ...
- TensorFlow深度学习笔记 循环神经网络实践
转载请注明作者:梦里风林 Github工程地址:https://github.com/ahangchen/GDLnotes 欢迎star,有问题可以到Issue区讨论 官方教程地址 视频/字幕下载 加 ...
- 开始学习深度学习和循环神经网络Some starting points for deep learning and RNNs
Bengio, LeCun, Jordan, Hinton, Schmidhuber, Ng, de Freitas and OpenAI have done reddit AMA's. These ...
- 深度学习原理与框架-RNN网络框架-LSTM框架 1.控制门单元 2.遗忘门单元 3.记忆门单元 4.控制门单元更新 5.输出门单元 6.LSTM网络结构
LSTM网络是有LSTM每个单元所串接而成的, 从下面可以看出RNN与LSTM网络的差异, LSTM主要有控制门单元和输出门单元组成 控制门单元又是由遗忘门单元和记忆门单元的加和组成. 1.控制门单元 ...
- 循环神经网络(RNN, Recurrent Neural Networks)介绍(转载)
循环神经网络(RNN, Recurrent Neural Networks)介绍 这篇文章很多内容是参考:http://www.wildml.com/2015/09/recurrent-neur ...
- 循环神经网络(RNN, Recurrent Neural Networks)介绍
原文地址: http://blog.csdn.net/heyongluoyao8/article/details/48636251# 循环神经网络(RNN, Recurrent Neural Netw ...
- 通过keras例子理解LSTM 循环神经网络(RNN)
博文的翻译和实践: Understanding Stateful LSTM Recurrent Neural Networks in Python with Keras 正文 一个强大而流行的循环神经 ...
- 深度学习:浅谈RNN、LSTM+Kreas实现与应用
主要针对RNN与LSTM的结构及其原理进行详细的介绍,了解什么是RNN,RNN的1对N.N对1的结构,什么是LSTM,以及LSTM中的三门(input.ouput.forget),后续将利用深度学习框 ...
随机推荐
- MongoDB pymongo模块 删除数据
使用user集合,删除user集合的数据 import pymongo mongo_client = pymongo.MongoClient( host='192.168.0.112', port=2 ...
- ftp工具类
package com.ytd.zjdlbb.service.zjdlbb; import java.io.File;import java.io.FileInputStream;import jav ...
- Spark Sql之ThriftServer和Beeline的使用
概述 ThriftServer相当于service层,而ThriftServer通过Beeline来连接数据库.客户端用于连接JDBC的Server的一个工具 步骤 1:启动metastore服务 . ...
- 日志文件系统syslog,syslog-ng
日志文件系统syslog,syslog-ng 余二五 2017-11-07 20:37:00 浏览127 评论0 日志 LOG 配置 主机 正则表达式 syslog 表达式 source file ...
- 虚拟机开启时 VMware Authorization Service 这个服务找不到的解决办法
有些时候我们启动虚拟机 会出现 The VMware Authorization Service is not running 正常情况下我们只要进 我的电脑-------> 管理------- ...
- 调用另一个文件的python代码【转载】
转自:https://blog.csdn.net/u010412719/article/details/47089883 例如我们有a.py和b.py两个文件,当我们需要在b.py文件中应用a.py中 ...
- NN中BP推导及w不能初始化为0
转自:为什么w不能初始化为0,而是要随机初始化?https://zhuanlan.zhihu.com/p/27190255 通俗理解BP.https://zhuanlan.zhihu.com/p/24 ...
- [LeetCode] 824. Goat Latin_Easy
A sentence S is given, composed of words separated by spaces. Each word consists of lowercase and up ...
- 时间序列模式(ARIMA)---Python实现
时间序列分析的主要目的是根据已有的历史数据对未来进行预测.如餐饮销售预测可以看做是基于时间序列的短期数据预测, 预测的对象时具体菜品的销售量. 1.时间序列算法: 常见的时间序列模型; 2.时序模 ...
- 关于原始input的一些事情
1.关于input type为number时 maxlength失效 <input class="myfrom-input" type="text" id ...