个基于TensorFlow的简单故事生成案例:带你了解LSTM
https://medium.com/towards-data-science/lstm-by-example-using-tensorflow-feb0c1968537
在深度学习中,循环神经网络(RNN)是一系列善于从序列数据中学习的神经网络。由于对长期依赖问题的鲁棒性,长短期记忆(LSTM)是一类已经有实际应用的循环神经网络。现在已有大量关于 LSTM 的文章和文献,其中推荐如下两篇:
Goodfellow et.al.《深度学习》一书第十章:http://www.deeplearningbook.org/
Chris Olah:理解 LSTM:http://colah.github.io/posts/2015-08-Understanding-LSTMs/
已存在大量优秀的库可以帮助你基于 LSTM 构建机器学习应用。在 GitHub 中,谷歌的 TensorFlow 在此文成文时已有超过 50000 次星,表明了其在机器学习从业者中的流行度。
与此形成对比,相对缺乏的似乎是关于如何基于 LSTM 建立易于理解的 TensorFlow 应用的优秀文档和示例,这也是本文尝试解决的问题。
假设我们想用一个样本短故事来训练 LSTM 预测下一个单词,伊索寓言:
long ago , the mice had a general council to consider what measures they could take to outwit their common enemy , the cat . some said this , and some said that but at last a young mouse got up and said he had a proposal to make , which he thought would meet the case . you will all agree , said he , that our chief danger consists in the sly and treacherous manner in which the enemy approaches us . now , if we could receive some signal of her approach , we could easily escape from her . i venture , therefore , to propose that a small bell be procured , and attached by a ribbon round the neck of the cat . by this means we should always know when she was about , and could easily retire while she was in the neighbourhood . this proposal met with general applause , until an old mouse got up and said that is all very well , but who is to bell the cat ? the mice looked at one another and nobody spoke . then the old mouse said it is easy to propose impossible remedies .
Listing 1.取自伊索寓言的短故事,其中有 112 个不同的符号。单词和标点符号都视作符号。
如果我们将文本中的 3 个符号以正确的序列输入 LSTM,以 1 个标记了的符号作为输出,最终神经网络将学会正确地预测下一个符号(Figure1)。
图 1.有 3 个输入和 1 个输出的 LSTM 单元
严格说来,LSTM 只能理解输入的实数。一种将符号转化为数字的方法是基于每个符号出现的频率为其分配一个对应的整数。例如,上面的短文中有 112 个不同的符号。如列表 2 所示的函数建立了一个有如下条目 [「,」: 0 ] [「the」: 1 ], …, [「council」: 37 ],…,[「spoke」= 111 ] 的词典。而为了解码 LSTM 的输出,同时也生成了逆序字典。
def build_dataset(words):
count = collections.Counter(words).most_common()
dictionary = dict()
for word, _ in count:
dictionary[word] = len(dictionary)
reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
return dictionary, reverse_dictionary
Listing 2.建立字典和逆序字典的函数
类似地,预测值也是一个唯一的整数值与逆序字典中预测符号的索引相对应。例如:如果预测值是 37,预测符号便是「council」。
输出的生成看起来似乎简单,但实际上 LSTM 为下一个符号生成了一个含有 112 个元素的预测概率向量,并用 softmax() 函数归一化。有着最高概率值的元素的索引便是逆序字典中预测符号的索引值(例如:一个 one-hot 向量)。图 2 给出了这个过程。
图 2.每一个输入符号被分配给它的独一无二的整数值所替代。输出是一个表明了预测符号在反向词典中索引的 one-hot 向量。
LSTM 模型是这个应用的核心部分。令人惊讶的是,它很易于用 TensorFlow 实现:
def RNN(x, weights, biases): # reshape to [1, n_input]
x = tf.reshape(x, [-1, n_input]) # Generate a n_input-element sequence of inputs
# (eg. [had] [a] [general] -> [20] [6] [33])
x = tf.split(x,n_input,1) # 1-layer LSTM with n_hidden units.
rnn_cell = rnn.BasicLSTMCell(n_hidden) # generate prediction
outputs, states = rnn.static_rnn(rnn_cell, x, dtype=tf.float32) # there are n_input outputs but
# we only want the last output
return tf.matmul(outputs[-1], weights['out']) + biases['out']
Listing 3.有 512 个 LSTM 单元的网络模型
最难部分是以正确的格式和顺序完成输入。在这个例子中,LSTM 的输入是一个有 3 个整数的序列(例如:1x3 的整数向量)
网络的常量、权值和偏差设置如下:
vocab_size = len(dictionary)
n_input = 3
# number of units in RNN cell
n_hidden = 512
# RNN output node weights and biases
weights = {
'out': tf.Variable(tf.random_normal([n_hidden, vocab_size]))
}
biases = {
'out': tf.Variable(tf.random_normal([vocab_size]))
}
Listing 4.常量和训练参数
训练过程中的每一步,3 个符号都在训练数据中被检索。然后 3 个符号转化为整数以形成输入向量。
symbols_in_keys = [ [dictionary[ str(training_data[i])]] for i in range(offset, offset+n_input) ]
Listing 5.将符号转化为整数向量作为输入
训练标签是一个位于 3 个输入符号之后的 one-hot 向量。
symbols_out_onehot = np.zeros([vocab_size], dtype=float)
symbols_out_onehot[dictionary[str(training_data[offset+n_input])]] = 1.0
Listing 6.单向量作为标签
在转化为输入词典的格式后,进行如下的优化过程:
_, acc, loss, onehot_pred = session.run([optimizer, accuracy, cost, pred], feed_dict={x: symbols_in_keys, y: symbols_out_onehot})
Listing 7.训练过程中的优化
精度和损失被累积以监测训练过程。通常 50,000 次迭代足以达到可接受的精度要求。
...
Iter= 49000, Average Loss= 0.528684, Average Accuracy= 88.50%
['could', 'easily', 'retire'] - [while] vs [while]
Iter= 50000, Average Loss= 0.415811, Average Accuracy= 91.20%
['this', 'means', 'we'] - [should] vs [should]
Listing 8.一个训练间隔的预测和精度数据示例(间隔 1000 步)
代价是标签和 softmax() 预测之间的交叉熵,它被 RMSProp 以 0.001 的学习率进行优化。在本文示例的情况中,RMSProp 通常比 Adam 和 SGD 表现得更好。
pred = RNN(x, weights, biases) # Loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate).minimize(cost)
Listing 9.损失和优化器
LSTM 的精度可以通过增加层来改善。
rnn_cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(n_hidden),rnn.BasicLSTMCell(n_hidden)])
Listing 10. 改善的 LSTM
现在,到了有意思的部分。让我们通过将预测得到的输出作为输入中的下一个符号输入 LSTM 来生成一个故事吧。示例输入是「had a general」,LSTM 给出了正确的输出预测「council」。然后「council」作为新的输入「a general council」的一部分输入神经网络得到下一个输出「to」,如此循环下去。令人惊讶的是,LSTM 创作出了一个有一定含义的故事。
had a general council to consider what measures they could take to outwit their common enemy , the cat . some said this , and some said that but at last a young mouse got
Listing 11.截取了样本故事生成的故事中的前 32 个预测值
如果我们输入另一个序列(例如:「mouse」,「mouse」,「mouse」)但并不一定是这个故事中的序列,那么会自动生成另一个故事。
mouse mouse mouse , neighbourhood and could receive a outwit always the neck of the cat . some said this , and some said that but at last a young mouse got up and said
Listing 12.并非来源于示例故事中的输入序列
示例代码可以在这里找到:https://github.com/roatienza/Deep-Learning-Experiments/blob/master/Experiments/Tensorflow/RNN/rnn_words.py
示例文本的链接在这里:https://github.com/roatienza/Deep-Learning-Experiments/blob/master/Experiments/Tensorflow/RNN/belling_the_cat.txt
小贴士:
1. 用整数值编码符号容易操作但会丢失单词的意思。本文中将符号转化为整数值是用来简化关于用 TensorFlow 建立 LSTM 应用的讨论的。更推荐采用 Word2Vec 将符号编码为向量。
2. 将输出表达成单向量是效率较低的方式,尤其当我们有一个现实的单词量大小时。牛津词典有超过 170,000 个单词,而上面的例子中只有 112 个单词。再次声明,本文中的示例只为了简化讨论。
3. 这里采用的代码受到了 Tensorflow-Examples 的启发:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/recurrent_network.py
4. 本文例子中的输入大小为 3,看一看当采用其它大小的输入时会发生什么吧(例如:4,5 或更多)。
5. 每次运行代码都可能生成不同的结果,LSTM 的预测能力也会不同。这是由于精度依赖于初始参数的随机设定。训练次数越多(超过 150,000 次)精度也会相应提高。每次运行代码,建立的词典也会不同
6. Tensorboard 在调试中,尤其当检查代码是否正确地建立了图时很有用。
7. 试着用另一个故事测试 LSTM,尤其是用另一种语言写的故事。
个基于TensorFlow的简单故事生成案例:带你了解LSTM的更多相关文章
- 基于TensorFlow的简单验证码识别
TensorFlow 可以用来实现验证码识别的过程,这里识别的验证码是图形验证码,首先用标注好的数据来训练一个模型,然后再用模型来实现这个验证码的识别. 生成验证码 首先生成验证码,这里使用 Pyth ...
- 利用tensorflow训练简单的生成对抗网络GAN
对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(di ...
- 深度学习(五)基于tensorflow实现简单卷积神经网络Lenet5
原文作者:aircraft 原文地址:https://www.cnblogs.com/DOMLX/p/8954892.html 参考博客:https://blog.csdn.net/u01287127 ...
- 基于tensorflow的简单线性回归模型
#!/usr/local/bin/python3 ##ljj [1] ##linear regression model import tensorflow as tf import matplotl ...
- 基于tensorflow的简单鼠标键盘识别
import cv2 as cvimport tensorflow as tfimport numpy as npimport random ##以下为数据预处理,分类为cata,总共样本为cata* ...
- 基于Tensorflow + Opencv 实现CNN自定义图像分类
摘要:本篇文章主要通过Tensorflow+Opencv实现CNN自定义图像分类案例,它能解决我们现实论文或实践中的图像分类问题,并与机器学习的图像分类算法进行对比实验. 本文分享自华为云社区< ...
- 基于 TensorFlow 在手机端实现文档检测
作者:冯牮 前言 本文不是神经网络或机器学习的入门教学,而是通过一个真实的产品案例,展示了在手机客户端上运行一个神经网络的关键技术点 在卷积神经网络适用的领域里,已经出现了一些很经典的图像分类网络,比 ...
- 基于TensorFlow的深度学习系列教程 2——常量Constant
前面介绍过了Tensorflow的基本概念,比如如何使用tensorboard查看计算图.本篇则着重介绍和整理下Constant相关的内容. 基于TensorFlow的深度学习系列教程 1--Hell ...
- Python3+Dlib实现简单人脸识别案例
Python3+Dlib实现简单人脸识别案例 写在前边 很早很早之前,当我还是一个傻了吧唧的专科生的时候,我就听说过人脸识别,听说过算法,听说过人工智能,并且也出生牛犊不怕虎般的学习过TensorFl ...
随机推荐
- 2016 Multi-University Training Contest 6 题解
我只能说: A Boring Question 下面公式重复了一行 \[ \sum\_{0\leq k\_{1},k\_{2},\cdots k\_{m}\leq n}\prod\_{1\leq j& ...
- 51Nod 1092 回文字符串(LCS + dp)
51Nod 1092 数据结构暑假作业上出现的一题,学习了一下相关算法之后,找到了oj测试能AC. 1.回文串是一种中心对称的结构,这道题可以转变为求最长回文子序列长度的题目.(子序列:可以不连续) ...
- Bus Blaster v4 design overview
Bus Blaster v4 design overview Bus Blaster v4 is an experimental, high-speed JTAG debugger for ARM p ...
- Introduction to Cortex Serial Wire Debugging
Serial Wire Debug (SWD) provides a debug port for severely pin limited packages, often the case for ...
- Mysql5.6主从复制-基于binlog
MySQL5.6开始主从复制有两种方式:基于日志(binlog):基于GTID(全局事务标示符). 此文章是基于日志方式的配置步骤 环境: master数据库IP:192.168.247.128sla ...
- 阅读Linux内核源码时建立tags索引
比如在阅读arm架构的Linux内核代码时想建立arm架构相关的索引,可以使用下面的命令: make ARCH=arm tags
- (ios7) 解决代码布局View, ios7 中 subView 高度增加StatusBar20dp的问题,保证Ios6,ios7代码一致
在ios7 布局中,Status Bar 和 ToolBar ,NavigateBar 等都包含在ViewControl的主View中. 这样原来ios6 的View布局 整体向上移动了20dp,下面 ...
- 【python】python安装步骤
1.官网下载python 官网地址:https://www.python.org/getit/ 2.下载完成后点击安装 勾选Add python to PATH 是可以自己去配置环境变量的 注意:这里 ...
- java把指定文字输出为图片流,支持文字换行
public class IamgeUtils { private static final int WIDTH = 350; private static final int HEIGHT = 10 ...
- 美国罪案故事第一季/全集American Crime Story迅雷下载
英文全名American Crime Story,第1季(2016)FX.本季看点:<美国罪案故事>以律师们的视角看待辛普森谋杀案. 本剧探索了案件背后各种混乱,以及案件双方的庭审策略,也 ...