Deep Learning基础--理解LSTM/RNN中的Attention机制
导读
目前采用编码器-解码器 (Encode-Decode) 结构的模型非常热门,是因为它在许多领域较其他的传统模型方法都取得了更好的结果。这种结构的模型通常将输入序列编码成一个固定长度的向量表示,对于长度较短的输入序列而言,该模型能够学习出对应合理的向量表示。然而,这种模型存在的问题在于:当输入序列非常长时,模型难以学到合理的向量表示。
在这篇博文中,我们将探索加入LSTM/RNN模型中的attention机制是如何克服传统编码器-解码器结构存在的问题的。
通过阅读这篇博文,你将会学习到:
- 传统编码器-解码器结构存在的问题及如何将输入序列编码成固定的向量表示;
- Attention机制是如何克服上述问题的,以及在模型输出时是如何考虑输出与输入序列的每一项关系的;
- 基于attention机制的LSTM/RNN模型的5个应用领域:机器翻译、图片描述、语义蕴涵、语音识别和文本摘要。
让我们开始学习吧。
一、长输入序列带来的问题
使用传统编码器-解码器的RNN模型先用一些LSTM单元来对输入序列进行学习,编码为固定长度的向量表示;然后再用一些LSTM单元来读取这种向量表示并解码为输出序列。
采用这种结构的模型在许多比较难的序列预测问题(如文本翻译)上都取得了最好的结果,因此迅速成为了目前的主流方法。
例如:
- Sequence to Sequence Learning with Neural Networks, 2014
- Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation, 2014
这种结构在很多其他的领域上也取得了不错的结果。然而,它存在一个问题在于:输入序列不论长短都会被编码成一个固定长度的向量表示,而解码则受限于该固定长度的向量表示。
这个问题限制了模型的性能,尤其是当输入序列比较长时,模型的性能会变得很差(在文本翻译任务上表现为待翻译的原始文本长度过长时翻译质量较差)。
“一个潜在的问题是,采用编码器-解码器结构的神经网络模型需要将输入序列中的必要信息表示为一个固定长度的向量,而当输入序列很长时则难以保留全部的必要信息(因为太多),尤其是当输入序列的长度比训练数据集中的更长时。”
— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015
二、使用Attention机制
Attention机制的基本思想是,打破了传统编码器-解码器结构在编解码时都依赖于内部一个固定长度向量的限制。
Attention机制的实现是通过保留LSTM编码器对输入序列的中间输出结果,然后训练一个模型来对这些输入进行选择性的学习并且在模型输出时将输出序列与之进行关联。
换一个角度而言,输出序列中的每一项的生成概率取决于在输入序列中选择了哪些项。
“在文本翻译任务上,使用attention机制的模型每生成一个词时都会在输入序列中找出一个与之最相关的词集合。之后模型根据当前的上下文向量 (context vectors) 和所有之前生成出的词来预测下一个目标词。
… 它将输入序列转化为一堆向量的序列并自适应地从中选择一个子集来解码出目标翻译文本。这感觉上像是用于文本翻译的神经网络模型需要“压缩”输入文本中的所有信息为一个固定长度的向量,不论输入文本的长短。”
— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015
虽然模型使用attention机制之后会增加计算量,但是性能水平能够得到提升。另外,使用attention机制便于理解在模型输出过程中输入序列中的信息是如何影响最后生成序列的。这有助于我们更好地理解模型的内部运作机制以及对一些特定的输入-输出进行debug。
“论文提出的方法能够直观地观察到生成序列中的每个词与输入序列中一些词的对齐关系,这可以通过对标注 (annotations) 权重参数可视化来实现…每个图中矩阵的每一行表示与标注相关联的权重。由此我们可以看出在生成目标词时,源句子中的位置信息会被认为更重要。”
— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015
三、大型图片带来的问题
被广泛应用于计算机视觉领域的卷积神经网络模型同样存在类似的问题: 对于特别大的图片输入,模型学习起来比较困难。
由此,一种启发式的方法是将在模型做预测之前先对大型图片进行某种近似的表示。
“人类的感知有一个重要的特性是不会立即处理外界的全部输入,相反的,人类会将注意力专注于所选择的部分来得到所需要的信息,然后结合不同时间段的局部信息来建立一个内部的场景表示,从而引导眼球的移动及做出决策。”
这种启发式方法某种程度上也可以认为是考虑了attention,但在这篇博文中,这种方法并不认为是基于attention机制的。
基于attention机制的相关论文如下:
- Recurrent Models of Visual Attention, 2014
- DRAW: A Recurrent Neural Network For Image Generation, 2014
- Multiple Object Recognition with Visual Attention, 2014
四、基于Attention模型的应用实例
这部分将列举几个具体的应用实例,介绍attention机制是如何用在LSTM/RNN模型来进行序列预测的。
1. Attention在文本翻译任务上的应用
文本翻译这个实例在前面已经提过了。
给定一个法语的句子作为输入序列,需要输出翻译为英语的句子。Attention机制被用在输出输出序列中的每个词时会专注考虑输入序列中的一些被认为比较重要的词。
我们对原始的编码器-解码器模型进行了改进,使其有一个模型来对输入内容进行搜索,也就是说在生成目标词时会有一个编码器来做这个事情。这打破了之前的模型是基于将整个输入序列强行编码为一个固定长度向量的限制,同时也让模型在生成下一个目标词时重点考虑输入中相关的信息。
— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015
Attention在文本翻译任务(输入为法语文本序列,输出为英语文本序列)上的可视化(图片来源于Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015)
2. Attention在图片描述上的应用
与之前启发式方法不同的是,基于序列生成的attention机制可以应用在计算机视觉相关的任务上,帮助卷积神经网络重点关注图片的一些局部信息来生成相应的序列,典型的任务就是对一张图片进行文本描述。
给定一张图片作为输入,输出对应的英文文本描述。Attention机制被用在输出输出序列的每个词时会专注考虑图片中不同的局部信息。
我们提出了一种基于attention的方法,该方法在3个标准数据集上都取得了最佳的结果……同时展现了attention机制能够更好地帮助我们理解模型地生成过程,模型学习到的对齐关系与人类的直观认知非常的接近(如下图)。
— Show, Attend and Tell: Neural Image Caption Generation with Visual Attention, 2016
Attention在图片描述任务(输入为图片,输出为描述的文本)上的可视化(图片来源于Attend and Tell: Neural Image Caption Generation with Visual Attention, 2016)
3. Attention在语义蕴涵 (Entailment) 中的应用
给定一个用英文描述的前提和假设作为输入,输出假设与前提是否矛盾、是否相关或者是否成立。
举个例子:
前提:在一个婚礼派对上拍照
假设:有人结婚了
该例子中的假设是成立的。
Attention机制被用于关联假设和前提描述文本之间词与词的关系。
我们提出了一种基于LSTM的神经网络模型,和把每个输入文本都独立编码为一个语义向量的模型不同的是,该模型同时读取前提和假设两个描述的文本序列并判断假设是否成立。我们在模型中加入了attention机制来找出假设和前提文本中词/短语之间的对齐关系。……加入attention机制能够使模型在实验结果上有2.6个点的提升,这是目前数据集上取得的最好结果…
Attention在语义蕴涵任务(输入是前提文本,输出是假设文本)上的可视化(图片来源于Reasoning about Entailment with Neural Attention, 2016)
4. Attention在语音识别上的应用
给定一个英文的语音片段作为输入,输出对应的音素序列。
Attention机制被用于对输出序列的每个音素和输入语音序列中一些特定帧进行关联。
…一种基于attention机制的端到端可训练的语音识别模型,能够结合文本内容和位置信息来选择输入序列中下一个进行编码的位置。该模型有一个优点是能够识别长度比训练数据长得多的语音输入。
Attention在语音识别任务(输入是音帧,输出是音素的位置)上的可视化(图片来源于Attention-Based Models for Speech Recognition, 2015)
5. Attention在文本摘要上的应用
给定一篇英文文章作为输入序列,输出一个对应的摘要序列。
Attention机制被用于关联输出摘要中的每个词和输入中的一些特定词。
… 在最近神经网络翻译模型的发展基础之上,提出了一个用于生成摘要任务的基于attention的神经网络模型。通过将这个概率模型与一个生成式方法相结合来生成出准确的摘要。
— A Neural Attention Model for Abstractive Sentence Summarization, 2015
Attention在文本摘要任务(输入为文章,输出为文本摘要)上的可视化(图片来源于A Neural Attention Model for Abstractive Sentence Summarization, 2015)
五、Attention的数学解释
1. 原来的Encoder–Decoder
在这个模型中,encoder只将最后一个输出递给了decoder,这样一来,decoder就相当于对输入只知道梗概意思,而无法得到更多输入的细节,比如输入的位置信息。所以想想就知道了,如果输入的句子比较短、意思比较简单,翻译起来还行,长了复杂了就做不好了嘛。
2. 对齐问题
前面说了,只给我递来最后一个输出,不好;但如果把每个step的输出都传给我,又有一个问题了,怎么对齐?
什么是对齐?比如说英文翻译成中文,假设英文有10个词,对应的中文翻译只有6个词,那么就有了哪些英文词对哪些中文词的问题了嘛。
传统的翻译专门有一块是搞对齐的,是一个比较独立的task(传统的NLP基本上每一块都是独立的task啦)。
3. attention机制
我们从输出端,即decoder部分,倒过来一步一步看公式。
$$ S_t=f(S_{t-1}, y_{t-1}, c_t) \tag{1} $$
$S_t$是指decoder在$t$时刻的状态输出,$S_{t-1}$是指decoder在$t-1$时刻的状态输出,$y_{t-1}$是$t-1$时刻的label(注意是label,不是我们输出的$y$),$c_t$看下一个公式,$f$是一个RNN。
$$ {c_{t}} = \sum\limits_{j = 1}^{{T_x}} {{a_{tj}}{h_j}} \tag{2} $$
$h_j$是指第$j$个输入在encoder里的输出,$a_{tj}$是一个权重
$$ {a_{tj}} = \frac{{exp \left( {{e_{tj}}} \right)}}{{\sum\nolimits_{k = 1}^{{T_x}} {exp \left( {{e_{tk}}} \right)} }} \tag{3}$$
这个公式跟softmax是何其相似,道理是一样的,是为了得到条件概率$P(a|e)$,这个$a$的意义是当前这一步decoder对齐第$j$个输入的程度。
最后一个公式,
$$ e_{tj} = g(S_{t-1}, h_j) = V\cdot \tanh { \left( W\cdot h_j+U\cdot S_{t-1}+b \right) } \tag{4}$$
这个$g$可以用一个小型的神经网络来逼近,它用来计算$S_{t-1}$, $h_j$这两者的关系分数,如果分数大则说明关注度较高,注意力分布就会更加集中在这个输入单词上,这个函数在文章Neural Machine Translation by Jointly Learning to Align and Translate(2014)中称之为校准模型(alignment model),文中提到这个函数是RNN前馈网络中的一系列参数,在训练过程会训练这些参数, 基于Attention-Based LSTM模型的文本分类技术的研究(2016)给出了上式的右侧部分作为拓展。
好了,把四个公式串起来看,这个attention机制可以总结为一句话:当前一步输出$S_t$应该对齐哪一步输入,主要取决于前一步输出$S_{t-1}$和这一步输入的encoder结果$h_j$。
看了这个方法的感受是,计算力发达的这个年代,真是什么复杂的东西都有人敢试了啊。这要是放在以前,得跑多久才能收敛啊......
神经网络搞NLP虽然还有诸多受限的地方,但这种end-to-end 的one task方式,太吸引人,有前途。
进一步的阅读
如果你想进一步地学习如何在LSTM/RNN模型中加入attention机制,可阅读以下论文:
- Attention and memory in deep learning and NLP
- Attention Mechanism
- Survey on Attention-based Models Applied in NLP
- What is exactly the attention mechanism introduced to RNN? (来自Quora)
- What is Attention Mechanism in Neural Networks?
目前Keras官方还没有单独将attention模型的代码开源,下面有一些第三方的实现:
- Deep Language Modeling for Question Answering using Keras
- Attention Model Available!
- Keras Attention Mechanism
- Attention and Augmented Recurrent Neural Networks
- How to add Attention on top of a Recurrent Layer (Text Classification)
- Attention Mechanism Implementation Issue
- Implementing simple neural attention model (for padded inputs)
- Attention layer requires another PR
- seq2seq library
总结
通过这篇博文,你应该学习到了attention机制是如何应用在LSTM/RNN模型中来解决序列预测存在的问题。
具体而言,采用传统编码器-解码器结构的LSTM/RNN模型存在一个问题:不论输入长短都将其编码成一个固定长度的向量表示,这使模型对于长输入序列的学习效果很差(解码效果很差)。而attention机制则克服了上述问题,原理是在模型输出时会选择性地专注考虑输入中的对应相关的信息。使用attention机制的方法被广泛应用在各种序列预测任务上,包括文本翻译、语音识别等。
感谢原作者Jason Brownlee。原文链接见:Attention in Long Short-Term Memory Recurrent Neural Networks
转载:http://www.jeyzhang.com/understand-attention-in-rnn.html
Deep Learning基础--理解LSTM/RNN中的Attention机制的更多相关文章
- 理解LSTM/RNN中的Attention机制
转自:http://www.jeyzhang.com/understand-attention-in-rnn.html,感谢分享! 导读 目前采用编码器-解码器 (Encode-Decode) 结构的 ...
- Deep Learning基础--理解LSTM网络
循环神经网络(RNN) 人们的每次思考并不都是从零开始的.比如说你在阅读这篇文章时,你基于对前面的文字的理解来理解你目前阅读到的文字,而不是每读到一个文字时,都抛弃掉前面的思考,从头开始.你的记忆是有 ...
- LSTM/RNN中的Attention机制
一.解决的问题 采用传统编码器-解码器结构的LSTM/RNN模型存在一个问题,不论输入长短都将其编码成一个固定长度的向量表示,这使模型对于长输入序列的学习效果很差(解码效果很差). 注意下图中,ax ...
- Deep Learning基础--CNN的反向求导及练习
前言: CNN作为DL中最成功的模型之一,有必要对其更进一步研究它.虽然在前面的博文Stacked CNN简单介绍中有大概介绍过CNN的使用,不过那是有个前提的:CNN中的参数必须已提前学习好.而本文 ...
- Deep Learning基础--参数优化方法
1. 深度学习流程简介 1)一次性设置(One time setup) -激活函数(Activation functions) - 数据预处理(Data Preprocessing) ...
- 深度学习中的Attention机制
1.深度学习的seq2seq模型 从rnn结构说起 根据输出和输入序列不同数量rnn可以有多种不同的结构,不同结构自然就有不同的引用场合.如下图, one to one 结构,仅仅只是简单的给一个输入 ...
- Deep Learning基础--word2vec 中的数学原理详解
word2vec 是 Google 于 2013 年开源推出的一个用于获取 word vector 的工具包,它简单.高效,因此引起了很多人的关注.由于 word2vec 的作者 Tomas Miko ...
- Deep Learning基础--随时间反向传播 (BackPropagation Through Time,BPTT)推导
1. 随时间反向传播BPTT(BackPropagation Through Time, BPTT) RNN(循环神经网络)是一种具有长时记忆能力的神经网络模型,被广泛用于序列标注问题.一个典型的RN ...
- 对循环神经网络参数的理解|LSTM RNN Input_size Batch Sequence
在很多博客和知乎中我看到了许多对于pytorch框架中RNN接口的一些解析,但都较为浅显甚至出现一些不准确的理解,在这里我想阐述下我对于pytorch中RNN接口的参数的理解. 我们经常看到的RNN网 ...
随机推荐
- 【bzoj3992】[SDOI2015]序列统计 原根+NTT
题目描述 求长度为 $n$ 的序列,每个数都是 $|S|$ 中的某一个,所有数的乘积模 $m$ 等于 $x$ 的序列数目模1004535809的值. 输入 一行,四个整数,N.M.x.|S|,其中|S ...
- Xmind8破解,以及相关的流程和破解包
一.下载XMindCrack.jar文件:(传的貌似被屏蔽了:如果需要请留下邮箱,抽空会发给你) 百度云 ,里面破解文件,安装包都给了,但Xmind安装包不一定是最新的,有需求的可自行去官网下载 . ...
- 【HLSDK系列】HL引擎入门篇
如果你打算拿HL的源码(也就是HLSDK)来改出一个自己的游戏,那你就非常有必要理解一些HL引擎的工作方式. HL引擎分成两个部分,服务端和客户端.服务端管理所有玩家的状态和游戏规则,客户端负责显示U ...
- 洛谷 P1576 最小花费
题目戳 题目描述 在n个人中,某些人的银行账号之间可以互相转账.这些人之间转账的手续费各不相同.给定这些人之间转账时需要从转账金额里扣除百分之几的手续费,请问A最少需要多少钱使得转账后B收到100元. ...
- 51nod1238 最小公倍数之和 V3 莫比乌斯函数 杜教筛
题意:求\(\sum_{i = 1}^{n}\sum_{j = 1}^{n}lcm(i, j)\). 题解:虽然网上很多题解说用mu卡不过去,,,不过试了一下貌似时间还挺充足的,..也许有时间用phi ...
- Implement Queue by Two Stacks
As the title described, you should only use two stacks to implement a queue's actions. The queue sho ...
- 【bzoj4002】有意义的字符串
Portal --> bzoj4002 Solution 虽然说这题有点强行但是感觉还是挺妙的,给你通项让你反推数列的这种==有点毒 补档时间 首先有一个东西叫做特征方程,我们可以 ...
- Redis基操
Redis key-value类型的缓存数据库 指定IP和端口连接redis: ./redis-cli -h ip -p port Redis基本操作命令 命令 返回值 简介 ping PONG 测试 ...
- webstorm 激活破解方法大全
webstorm 作为最近最火的前端开发工具,也确实对得起那个价格,但是秉着勤俭节约的传统美德,我们肯定是能省则省啊. 方法一:(更新时间:2018/1/23)v3.3 注册时,在打开的License ...
- 二型错误和功效(Type II Errors and Test Power)
sklearn实战-乳腺癌细胞数据挖掘(博主亲自录制视频教程) https://study.163.com/course/introduction.htm?courseId=1005269003&am ...