学一学Transfomer
017年,Google发表论文《Attention is All You Need》,提出经典网络结构Transformer,全部采用Attention结构的方式,代替了传统的Encoder-Decoder框架必须结合CNN或RNN的固有模式。并在两项机器翻译任务中取得了显著效果。该论文一经发出,便引起了业界的广泛关注,同时,Google于2018年发布的划时代模型BERT也是在Transformer架构上发展而来。所以,为了之后学习的必要,本文将详细介绍Transformer模型的网络结构。
一 理论基础
1、整体架构
Transformer作为seq2seq,也是由经典的Encoder-Decoder模型组成。在上图中,整个Encoder层由6个左边Nx部分的结构组成。整个Decoder由6个右边Nx部分的框架组成,Decoder输出的结果经过一个线性层变换后,经过softmax层计算,输出最终的预测结果。
(1)、Encoder结构:
输入序列X经过word embedding和positional encoding做直接加和后,作为Encoder部分的输入。输入向量经过一个multi-head self-attention层后,做一次residual connection(残差连接)和Layer Normalization(层归一化,下文中简称LN),输入到下一层position-wise feed-forward network中。之后再进行一次残差连接+LN,输出到Decoder部分,这里所涉及到的相关知识会在下文中详细介绍。
(2)、Decoder结构:
输出序列Y经过word embedding和positional encoding做直接加和后,作为Decoder部分的输入。很多对seq2seq不了解的朋友看到这里可能有些糊涂,简单说明以下。以翻译任务为例,假设我们要进行一个中译英任务。我们现在有一段中文序列X,对应的英文序列Y。我们在翻译出某个单词Yt时,并非只是用中文序列X翻译,而是用中文序列X加已经翻译出来的英文序列(y1,y2,……yt-1)进行翻译,所以也要将已经翻译出来的英文序列输入其中。这也就解释了为什么会将输出序列Y作为Decoder的输入。在论文中,在训练过程中为了处理方便同时不引入未来信息,采用了一种sequence masking机制,具体的实现下文再详细介绍。
Decoder部分的输入向量首先经过一层multi-head self-attention,进行一次残差连接+LN,再经过一层multi-head context-attention,进行一次残差连接+LN,最后再经过一层position-wise feed-forward network,进行一次残差连接+LN后,输出至线性层。
以上介绍了Encoder和Decoder的基本流程,相信大家对其中具体的实现还有不明白的细节,下面我就来为大家一一阐述。
2、Attention机制
上文中提到了两个Attenton结构,multi-head self-attention和multi-head context-attention可以说是本文中最重要的概念,这里来解释下两者的实现,首先,我们来回顾以下基础的Attention机制。
(1)、基础Attention机制
之前曾经写过一篇详细介绍Attention的文章,感兴趣的朋友可以关注我的公众号查找,这里主要使用论文中描述的方式来简单介绍以下基础Attention。
在自然语言处理中,Attention的本质可以理解为一个查询(query)到一些列(key - value)对的映射。以基础的Attention计算公式为例:
计算attention时:第一步,将query和每个key进行相似度计算得到权重,即上图中的第三个公式。第二步,一般使用一个softmax函数将这些权重进行归一化,即上图中的第二个公式,最后将权重和相应的键值value进行加权求和,得到最终的attention,即第一个公式。通常key和value取值相同,例如上图中,key=value=hj, query=si-1。
其实,Google所用到的基本attention思路是与上面一致的,只是在计算Attention分数时,采用了另一种计算机制:Scaled dot-product attention
(2)、Scaled dot-product attention
Scaled dot-product attention的计算公式如下:
其实基本元素还是Q,K,V三项,无非就是公式变了下。具体的计算图结构文章中也给了图,公式很清晰这里就不列了。
(3)Self-attention 和Context-attention
Self-attention:自己跟自己做Attention,输入序列=输出序列。Q=K=V。
Context-attention:Encoder输出结果跟Decoder第一部分输出结果之间做Attention。
具体到网络结构中:
Encoder中的self-attention,Q,K,V均为Encoder的输入。
Decoder中的self-attention,Q,K,V均为Decoder的输入,也就是上一层Decoder的输入,具体原因见Decoder的介绍。
Decoder中context-attention,Q为decoder中第一部分的输出,K,V均为encoder的输出。
(4)、Multi-head attention
论文中采用的Multi-head attention,就是将Q, K, V先经过一个线性映射,再在在输入维度dk,dq,dv上切分成h份,再对每一份进行Scaled dot-product attention,之后将每部分结果合并起来,经过线性映射,得到最终的输出,结构图如下:
说的有些绕,举个例子,原文中d=512(即词向量和位置向量的维度),h=8。那么假设原始输入为[batch_size*seq_len*512]的三维表,处理后共分成8份[batch_size*seq_len* 64]的三维表,每份分别做Scaled dot-product,就是Multi-head attention了。这样进行了h次运算,可以允许模型在不同的表示子空间中学习到相关信息。
以上就是Attention部分的全部讲解,说清楚这一部分,其他的都是一些零碎的细节。
3、Position-wise Feed-Forward network
一个全联接神经网络,先进行一次线性变换,再通过一次ReLU激活函数,最后再进行一次线性变化。公式如下:
4、Positional encoding
位置编码,顾名思义,对序列中词语的位置进行编码,公式如下:
即奇数位置用余弦编码,偶数位置用正弦编码,最终得到一个512维的位置向量。
5、Residual connection
残差连接其实在很多网络机构中都有用到。原理很简单,假设一个输入向量x,经过一个网络结构,得到输出向量f(x),加上残差连接,相当于在输出向量中加入输入向量,即输出结构变为f(x)+x,这样做的好处是在对x求偏导时,加入一项常数项1,避免了梯度消失的问题。
6、Layer Normalization
归一化的本质都是将数据转化为均值为0,方差为1的数据。这样可以减小数据的偏差,规避训练过程中梯度消失或爆炸的情况。我们在训练网络中比较常见的归一化方法是Batch Normalization,即在每一层输出的每一批数据上进行归一化。而Layer Normalization与BN稍有不同,即在每一层输出的每一个样本上进行归一化。
7、Mask
mask的思想非常简单:就是对输入序列中没某些值进行掩盖,使其不起作用。在论文中,做multi-head attention的地方用到了padding mask,在decode输入数据中用到了sequence mask。
(1)、padding mask
在我们输入的数据中,因为每句话的长度不同,所以要对较短的数据进行填充补齐长度。而这些填充值并没有什么作用,为了减少填充数据对attention计算的影响,采用padding mask的机制,即在填充物的位置上加上一个趋紧于负无穷的负数,这样经过softmax计算后这些位置的概率会趋近于0
(2)、sequence mask
在上文中我们提到,预测t时刻的输出值yt,应该使用全部的输入序列X,和t时刻之前的输出序列(y1,y2,……,yt-1)进行预测。所以在训练时,应该将t-1时刻之后的信息全部隐藏掉。所以需要用到sequence mask。
实现也很简单,就是用一个上三角矩阵,上三角值均为1,下三角值均为0,对角线值为0,与输入序列相乘,就达到了目的。
以上就是Transformer框架的全部知识点,BERT模型也是在此基础上发展而来。
二: 例子
先来看一个翻译的例子“I arrived at the bank after crossing the river” 这里面的bank指的是银行还是河岸呢,这就需要我们联系上下文,当我们看到river之后就应该知道这里bank很大概率指的是河岸。在RNN中我们就需要一步步的顺序处理从bank到river的所有词语,而当它们相距较远时RNN的效果常常较差,且由于其顺序性处理效率也较低。Self-Attention则利用了Attention机制,计算每个单词与其他所有单词之间的关联,在这句话里,当翻译bank一词时,river一词就有较高的Attention score。利用这些Attention score就可以得到一个加权的表示,然后再放到一个前馈神经网络中得到新的表示,这一表示很好的考虑到上下文的信息。如下图所示,encoder读入输入数据,利用层层叠加的Self-Attention机制对每一个词得到新的考虑了上下文信息的表征。Decoder也利用类似的Self-Attention机制,但它不仅仅看之前产生的输出的文字,而且还要attend encoder的输出。
三: 对自注意力机制的感性认识
A: 从宏观视角看自注意力机制
例如,下列句子是我们想要翻译的输入句子:
The animal didn't cross the street because it was too tired
这个“it”在这个句子是指什么呢?它指的是street还是这个animal呢?这对于人类来说是一个简单的问题,但是对于算法则不是。
当模型处理这个单词“it”的时候,自注意力机制会允许“it”与“animal”建立联系。
随着模型处理输入序列的每个单词,自注意力会关注整个输入序列的所有单词,帮助模型对本单词更好地进行编码。
如果你熟悉RNN(循环神经网络),回忆一下它是如何维持隐藏层的。RNN会将它已经处理过的前面的所有单词/向量的表示与它正在处理的当前单词/向量结合起来。而自注意力机制会将所有相关单词的理解融入到我们正在处理的单词中。
当我们在编码器#5(栈中最上层编码器)中编码“it”这个单词的时,注意力机制的部分会去关注“The Animal”,将它的表示的一部分编入“it”的编码中。
请务必检查Tensor2Tensor notebook ,在里面你可以下载一个Transformer模型,并用交互式可视化的方式来检验。
B: 从微观视角看自注意力机制
首先我们了解一下如何使用向量来计算自注意力,然后来看它实怎样用矩阵来实现。
计算自注意力的第一步就是从每个编码器的输入向量(每个单词的词向量)中生成三个向量。也就是说对于每个单词,我们创造一个查询向量、一个键向量和一个值向量。这三个向量是通过词嵌入与三个权重矩阵后相乘创建的。
可以发现这些新向量在维度上比词嵌入向量更低。他们的维度是64,而词嵌入和编码器的输入/输出向量的维度是512. 但实际上不强求维度更小,这只是一种基于架构上的选择,它可以使多头注意力(multiheaded attention)的大部分计算保持不变。
X1与WQ权重矩阵相乘得到q1, 就是与这个单词相关的查询向量。最终使得输入序列的每个单词的创建一个查询向量、一个键向量和一个值向量。
参考文献:
1 经典英文博客中文翻译: https://zhuanlan.zhihu.com/p/54356280
2 科学: https://spaces.ac.cn/archives/4765
3 AI机动队: https://zhuanlan.zhihu.com/p/47282410
4 深度学习中的注意力机制 https://blog.csdn.net/songbinxu/article/details/80739447
学一学Transfomer的更多相关文章
- 给大一的学弟学妹们培训java web的后台开发讨论班计划
蓝旭工作室5月大一讨论班课程计划 课时 讨论班性质 讨论班名称 主要内容 主讲人 第一讲 先导课 后台开发工具的使用与MySQL数据库基础 后台开发工具的基本使用方法与工程的创建,MySQL数 ...
- 11th 回忆整个学期——告学弟学妹
告诉后来的学弟学妹,不要因为艰难而却步,坚持下去才知道,山的对面是什么.很多东西或许一开始看起来是无用,甚至无意义的,但是努力去做,你才知道价值所在.不要等一切结束了,才懂得自己错过了什么.
- NOIP2018学军中学游记(11.09~11.11)
前言 这篇博客记录的是我在\(NOIP2018\)提高组比赛中的经历. 这一次的\(NOIP\)是在学军中学举办的, 莫名感到一阵慌张. 但愿能有一个好成绩,不然就要\(AFO\)了... ... 说 ...
- 准备学一学go-lang啦 ~~ 学习Go应该用什么姿势? !
go毕竟是新语言,没有那么多历史包袱,并且是google出品,c语言创始人为语言设计组成员,应该还是不错的. go天生具有并行能力,这个在现代服务器端编程领域作用很显而易见,高效服务,快速编码,适合互 ...
- Lucene/ElasticSearch 学习系列 (1) 为什么学,学什么,怎么学
为什么学 <What I wish I knew When I was 20>这本书给了我很多启发.作者在书中提到,Stanford 大学培养人才的目标是 ”T形人才“:精通某个领域,但对 ...
- 【转】科大校长给数学系学弟学妹的忠告&本科数学参考书
1.老老实实把课本上的题目做完.其实说科大的课本难,我以为这话不完整.科大的教材,就数学系而言还是讲得挺清楚的,难的是后面的习题.事实上做1道难题的收获是做10道简单题所不能比的. 2.每门数学必修课 ...
- 授人以鱼不如授人以渔——和女儿学一起学成语
女儿二年级了,前段时间背了<小学生必背古诗词75首>,采用几天一篇,然后滚动复习这种方式.磕磕绊绊也把一本古诗背了一遍,效果吗?是有的,但是不怎么明显,前面背,后面忘.当然这是规律,难免的 ...
- 学password学一定得学程序
题目描写叙述 以前.ZYJ同学非常喜欢password学.有一天,他发现了一个非常长非常长的字符串S1.他非常好奇那代表着什么,于是奇妙的WL给了他还有一个字符串S2.可是非常不幸的是,WL忘记跟他说 ...
- 开局一张图,学一学项目管理神器Maven!
Maven强大的Java工程构建工具,做Java开发时少了跟Maven打交道,之前在知乎上看到有人提问:"学Java开发需不需要学习Maven?",个人认为是必需要学的,这和工欲善 ...
随机推荐
- 深入理解Java虚拟机——读书笔记
首先 强烈推荐周志明老师的这本书,真的可以说是(起码中文出版界)新手了解Java虚拟机必须人手一本的教科书!!! 第二部分自动内存管理机制 由于Java虚拟机的多线程是通过线程轮流切换并分配处理器 ...
- javascript代码模块化解决方案
我们用模块化的思想进行网页的编写是为了更好的管理我们的项目,模块与模块之间是独立存在的,每个模块可以独立的完成一个子功能. 一.服务器和桌面环境中的Javascript代码模块化:CommonJS M ...
- Python3下UnicodeDecodeError:‘ASCII’ codec cant decode..(128)
今天准备用Keras跑一下LeNet的程序,结果总是编码出错 源代码是2.7写的,编码格式是utf-8.然后尝试网上各种方法不适用,最后还是解决了 源代码: data = gzip.open(r'C: ...
- flutter 跳转至根路由
上代码 //flutter 登录后跳转到根路由 Navigator.of(context).pushAndRemoveUntil( new MaterialPageRoute(builder: (co ...
- 计算机网络|C语言Socket编程,实现两个程序间的通信
C语言Socket编程,实现两个程序间的通信 server和client通信流程图 在mooc上找到的,使用Socket客户端client和服务端server通信的流程图
- redis序列化和反序列化
RedisTemplate中需要声明4种serializer,默认为“JdkSerializationRedisSerializer”: 1) keySerializer :对于普通K-V操作时,ke ...
- Java中判断两个Long类型是否相等
在项目中将两个long类型的值比较是否相等,结果却遇到了疑问? 下面就陪大家看看一个神奇的现象! 1.1问题?为什么同样的类型,同样的值,却不相等呢? 1.2那么我们就需要探索一下源码了 源码中显示, ...
- [Luogu] 计算系数
https://www.luogu.org/problemnew/show/P1313#sub Answer = a ^ n * b ^ m * C(k, min(n, m)) 这里用费马小定理求逆 ...
- neo4j︱与python结合的py2neo使用教程
—- 目前的几篇相关:—– neo4j︱图数据库基本概念.操作罗列与整理(一) neo4j︱Cypher 查询语言简单案例(二) neo4j︱Cypher完整案例csv导入.关系联通.高级查询(三) ...
- Friend-Graph (HDU 6152)2017中国大学生程序设计竞赛 - 网络选拔赛
Problem Description It is well known that small groups are not conducive of the development of a tea ...