简单而言,seq2seq由两个RNN组成,一个是编码器(encoder),一个是解码器(decoder).以MT为例,将源语言“我爱中国”译为“I love China”,则定义序列:
\[
X=(x_0,x_1,x_2,x_3)\\
其中,x_0=“我”,x_1=“爱”,x_2=“中”,x_3=“国”
\]
另外目标序列:
\[
Y=(y_0,y_1,y_2)="I\ love\ China"
\]
通过编码器将\(X=(x_0,x_1,x_2,x_3)\)映射为隐层状态\(h\),再经由解码器将\(h\)映射为\(Y=(y_0,y_1,y_2)\)

通常使用\(h​\)表示编码器的隐状态;用\(s​\)表示解码器的隐状态

注意:编码器输入和解码器输出向量的维度可以不同,最后将预测T和真实目标序列T‘做loss(通常是交叉熵)训练网络。

注意力机制

通过编码器,把\(X=(x_1,x_2,x_3,x_4)\)映射为一个隐层状态\(H=(h_0,h_1,h_2,h_3)\),解码器将\(H=(h_0,h_1,h_2,h_3)\)映射为\(Y=(y_0,y_1,y_2)\)。在带注意力机制的编解码器中,\(Y\)中的每一个元素都与\(H\)中的所有元素相连,而解码器的每个元素通过不同的权值给予编码器输出\(Y\)不同的贡献。

解码器输出有3个:

  • 上一解码步的隐状态(\(s_{t-1}\))

  • 上一解码步的输出(\(y_{t-1}\))

  • 注意力输出(编码器输出的加权和,context,是编码器端发给解码器信息的地方,由所有的编码器输出得到一个定长的向量,代表输入序列的全局信息,作为当前解码步的上下文),计算方法为:
    \[
    c_i=\sum_{j=1}^{T_x}\alpha_{ij}h_j
    \]
    其中,\(\alpha_{ij}\)是权重(\(\alpha_{ij}\)是标量,\(\alpha\)是二阶张量),又称作alignment;\(h\)是编码器所有时间步上的隐状态,又称作value或memory;\(i\)表示解码步,\(j\)表示编码步,输出\(c_i\)是和\(h_j\)同样大小的向量。

    在时间\(i\)上,
    \[
    c_i=\sum_{j=1}^{T_x}\alpha_{ij}h_j=\alpha_{i,1}h_1+\alpha_{i,2}h_2+...+\alpha_{i,T_x}h_{T_x}
    \]
    其中,\(c_i\)是与编码器输出\(h_j\)等大的向量;\(j\)为编码步;\(i\)为解码步;\(\alpha_{ij}\)为标量,计算方式:
    \[
    \alpha_{ij}=\frac{exp(e_{ij})}{\sum_{k=1}^{T_x}exp(e_{ik})}
    \]
    其中,\(e_{ij}=a(s_{i-1,j},h_j)\),表征\(s_{i-1}\)和\(h_j\)的相关程度,即对于某个给定的解码步,计算上一解码步的隐状态和所有编码步输出的相关程度,并且用softmax做归一化。这样,与上一解码步状态相关度大的编码器输出\(h\)的权重就大,在本解码步的整个上下文里面所占的比重就多,解码器在本时间步上解码时就越依赖这个编码器的输出\(h\).

    \(e_{ij}\)又被称作能量函数,\(a(·)\)的计算方法:

    1. 对\(s_{i-1}\)做线性映射,得到向量作为query(解码器上一时间步隐状态作为“查询”),记作\(q_i\)
    2. 对\(h_j\)做线性映射,得到向量作为key(编码器每一个时间步上的结果作为key待查),记作\(k_j\)
    3. \(e_{ij}=v^T(q_i+k_j)\),\(q_i\)和\(k_j\)的维度必须相同,同为d维;\(v\)是一个d×1的向量,从而得到的\(e_{ij}\)是一个标量

    1、2中的线性映射都是待训练的,3中的\(v\)也是待训练的。

    对query和key求相关性从而获得权重(alignment),用该权重对value加权和从而得到上下文送入解码器。

    小结

    • 3中query和key做加法,之后通过一个权重变为标量。这被称作“加性注意力”,相应的,可以做元素乘,被称作“乘性注意力”
    • location-sensitive,认为相邻\(\alpha_{ij}\)之间的关系会相对较大,为了捕获这种关系对alignment进行了卷积。
    • query有多种,不仅仅有上一解码步的隐状态;也有当前解码步的隐状态;还有将上一解码步上的隐状态和上一解码步的输出拼接作为query。但在TTS中,将上一解码步的隐状态和输出拼接作为query并不好,原因可能是可能两者不在同一空间,因此要具体问题具体分析。

Transformer

  • 左右分别是编码器和解码器

  • 编码器和解码器的底部都是embedding,而embedding又分为两部分:input embeddingpositional embedding,其中input embedding就是NLP中常见的词嵌入。因为Transformer中只有attention,对于一对(query, key),无论这对query-key处在什么位置,其计算都是相同的。不像CNN或RNN有一个位置或时序的差异:CNN框住的是一块区域,随着卷积核的移动,卷积核边缘的点也随着有序变化;RNN则更为明显,不同时序的\(h_t\)和\(s_t\)不同,而且是随着输入顺序(正/倒序)而不同。

    因此Transformer为了体现出时序或者序列中的位置差异,要对input加入一定的位置信息,这即是position embedding。求位置id为pos的位置编码向量:
    \[
    \left\{\begin{matrix}
    PE(pos,i)=sin(\frac{pos}{10000^{\frac{i}{d_{model}}}}),\ 若i为奇数
    \\
    PE(pos,i)=cos(\frac{pos}{10000^{\frac{i}{d_{model}}}})\ 若i为偶数
    \end{matrix}\right.
    PE向量第i维的求解方法
    \]
    编码器和解码器输入序列shape: \([T,d_{model}]\),即每个时刻的\(x_i\)都是\(d_{model}\)维的,因此\(pos\in [0,T]\),\(i\in[0,d_{model}]\)。即对于输入的\([T,d_{model}]\)的一个张量,其中的每一个标量都对应一个独特的编码结果,可以理解为给embedding一个低频信号,让其周期性波动,而且每个维度波动都不相同,以表征其id信息。

  • 编码器和解码器的中部分别是两个块,分别输入一个序列,输出一个序列,这两个块重复N次。编码器的每个块里有两个子网,分别为Multi-Head Attention和Feed Forward Network(FFN);解码器的每个块里有三个子网,分别是2个Multi-Head Attention和一个FFN。这些子网之后都跟一个add & norm,就是像ResNet那样做一个残差,然后加一个layer normalization。

  • 解码器最后还有个linear和softmax

FFN

FFN就是对一个输入序列\(X=(x_0,x_1,...,x_T)\),对每一个\(x_i\)都进行一次channel的重组:512 -> 2048 -> 512,可以理解为对每个\(x_i\)进行两次线性映射,也可以对整个序列进行1×1卷积。

Multi-Head Attention

原始的attention就是一个query(Q)和一组key(K)算相似度,然后对一组value(V)做加权和。假如每个Q和K都是512维的向量,就相当于在512维的空间里比较两个向量的相似度。而Multi-Head相当于加过于512维的空间人为拆分为多个子空间,如head number=8就是将高维空间拆分为8个子空间,相应地V也要分为8个head,然后在这8个子空间中分别计算Q和K的相似度,再组合V。这样能使attention从不同角度捕获序列关系。

  • 编码器
    \[
    sub\_layer\_output=LayerNorm(x+SubLayer(x))\\
    head_i=Attention(QW_i^Q,KW_i^K,VW_i^K)\\
    MultiHead(Q,K,V)=concat(head_1,head_2,...,head_h)W^O
    \]
    self-attention时,Q、K、V相同

  • 解码器

    • 输入:编码器的输出 & 对应i-1时刻的解码器输出(i-1步的hidden state和i-1步的输出)

      注意:在解码器中间的attention不是self-attention,其K、V来自编码器,Q来自上一时刻的解码器输出

    • 输出:i时刻的输出词的概率分布

    • 解码:编码可以并行,一次性全部编码出来(在编码时,各个计算互不依赖)。但解码不是一次把所有序列解出来,而是如同RNN一样,一个一个解出来,因为要用到上一解码步的隐状态作为attention的query。解码器端最先的Multi-Head是Masked,这是因为训练时输入是ground truth,这样确保预测第i个位置时,遮蔽掉该位置及其之后的信息,不会接触未来的信息。

Transformer优缺点

  • 优点

    • 并行计算,这主要体现在编解码器都放弃了RNN,下一个时间步的计算不必等待之前的计算完全展开

    • 直接的长距离依赖

      原来的RNN中,第一帧要和第十帧发生关系,必须通过第二~九帧传递,进而产生两者的计算。而在这个过程中,第一帧的信息有可能已经产生了偏差,准确性和速度都难以保证。在Transformer中,由于self-attention的存在,任意两帧都有直接的交互,建立直接依赖。

  • 缺点

    仍然是自回归模型,任意一帧的输出都依赖于它之前的所有输出。比如输入abc,本次的输出实际是bcd,每输入一个序列,其实序列的末端都只是前进了一帧,因此要生成abcdefg仍然要循环6次。

seq2seq和Transformer的更多相关文章

  1. BERT解析及文本分类应用

    目录 前言 BERT模型概览 Seq2Seq Attention Transformer encoder部分 Decoder部分 BERT Embedding 预训练 文本分类试验 参考文献 前言 在 ...

  2. 【NLP】老司机带你入门自然语言处理

    自然语言处理是一门用于理解人类语言.情感和思想的技术,被称为是人工智能皇冠上的明珠. 随着深度学习发展,自然语言处理技术近年来发展迅速,在技术上表现为BERT.GPT等表现极佳的模型:在应用中表现为c ...

  3. 论文解读丨表格识别模型TableMaster

    摘要:在此解决方案中把表格识别分成了四个部分:表格结构序列识别.文字检测.文字识别.单元格和文字框对齐.其中表格结构序列识别用到的模型是基于Master修改的,文字检测模型用到的是PSENet,文字识 ...

  4. Bert不完全手册2. Bert不能做NLG?MASS/UNILM/BART

    Bert通过双向LM处理语言理解问题,GPT则通过单向LM解决生成问题,那如果既想拥有BERT的双向理解能力,又想做生成嘞?成年人才不要做选择!这类需求,主要包括seq2seq中生成对输入有强依赖的场 ...

  5. NLP学习(5)----attention/ self-attention/ seq2seq/ transformer

    目录: 1. 前提 2. attention (1)为什么使用attention (2)attention的定义以及四种相似度计算方式 (3)attention类型(scaled dot-produc ...

  6. seq2seq模型详解及对比(CNN,RNN,Transformer)

    一,概述 在自然语言生成的任务中,大部分是基于seq2seq模型实现的(除此之外,还有语言模型,GAN等也能做文本生成),例如生成式对话,机器翻译,文本摘要等等,seq2seq模型是由encoder, ...

  7. Transformer【Attention is all you need】

    前言 Transfomer是一种encoder-decoder模型,在机器翻译领域主要就是通过encoder-decoder即seq2seq,将源语言(x1, x2 ... xn) 通过编码,再解码的 ...

  8. 【译】深度双向Transformer预训练【BERT第一作者分享】

    目录 NLP中的预训练 语境表示 语境表示相关研究 存在的问题 BERT的解决方案 任务一:Masked LM 任务二:预测下一句 BERT 输入表示 模型结构--Transformer编码器 Tra ...

  9. 【译】图解Transformer

    目录 从宏观上看Transformer 把张量画出来 开始编码! 从宏观上看自注意力 自注意力的细节 自注意力的矩阵计算 "多头"自注意力 用位置编码表示序列的顺序 残差 解码器 ...

随机推荐

  1. 【t007】棋盘放置指南车问题

    Time Limit: 1 second Memory Limit: 50 MB [问题描述] 按照国际象棋的规则,车可以攻击与之处在同一行或同一列上的棋子.指南车是有方向的车.横向指南车可以攻击与之 ...

  2. 探险 - 树型dp(背包)/多叉树转二叉树

    题目大意: 国家探险队长 Jack 意外弄到了一份秦始皇的藏宝图,于是,探险队一行人便踏上寻宝之旅,去寻找传说中的宝藏. 藏宝点分布在森林的各处,每个点有一个值,表示藏宝的价值.它们之间由一些小路相连 ...

  3. BZOJ 1509 逃学的小孩 - 树型dp

    传送门 题目大意: 在一棵树中, 每条边都有一个长度值, 现要求在树中选择 3 个点 X.Y. Z , 满足 X 到 Y 的距离不大于 X 到 Z 的距离, 且 X 到 Y 的距离与 Y 到 Z 的距 ...

  4. 简单的JAVA MVC框架模式--Java-servlet-JavaBean

    MVC全名是Model View Controller,是模型(model)-视图(view)-控制器(controller)的缩写,一种软件设计典范,用一种业务逻辑.数据.界面显示分离的方法组织代码 ...

  5. qLibc 对于C C++都是一个很好的框架,提供Tree Hash Stack String I/O File Time等功能

    qLibc Copyright qLibc is published under 2-clause BSD license known as Simplified BSD License. Pleas ...

  6. twemproxy

    twemproxy架构分析——剖析twemproxy代码前编   twemproxy背景 在业务量剧增的今天,单台高速缓存服务器已经无法满足业务的需求, 而相较于大容量SSD数据存储方案,缓存具备速度 ...

  7. Method of packet transmission from node and content owner in content-centric networking

    A method of transmitting a content reply packet from a content owner in content-centric networking ( ...

  8. 1 tcp/ip协议

    该协议是一个协议族,并是说具体某个协议下图中的协议都属于tcp/ip协议.他是用来规范互联网中电脑间数据传输的. 该协议可以分为4层或者7层 4层,实际层数: 链路层 网络层 传输层 应用层 7层,理 ...

  9. python3 操作注册表

    1.1 读取 import winreg key = winreg.OpenKey(winreg.HKEY_CURRENT_USER,r"Software\Microsoft\Windows ...

  10. 使用 install.packages() 安装所需的包

    1. 从 CRAN 上安装 install.packages("tm", dependencies = TRUE) tm 程序包用于文本挖掘(text mining) 2. 本地安 ...