[NLP] 相对位置编码(二) Relative Positional Encodings - Transformer-XL
1. Motivation
在Transformer-XL中,由于设计了segments,如果仍采用transformer模型中的绝对位置编码的话,将不能区分处不同segments内同样相对位置的词的先后顺序。
比如对于$segment_i$的第k个token,和$segment_j$的第k个token的绝对位置编码是完全相同的。
鉴于这样的问题,transformer-XL中采用了相对位置编码。
2. Relative Positional Encodings
paper中,由对绝对位置编码变换推导出新的相对位置编码方式。
vanilla Transformer中的绝对位置编码
它对每个index的token都通过sin/cos变换,为其唯一指定了一个位置编码。该位置编码将与input的embedding求sum之后作为transformer的input。
那么如果将该位置编码应用在transformer-xl会怎样呢?
其中$\tau$表示第$\tau$个segment, 是当前segment的序列$s_{\tau}$的word embedding sequence, $L$是序列长,$d$是每个word embedding的维度。$U_{1:L}$表示该segment中每个token的绝对位置编码组成的序列。
可以看到对于$h_{\tau + 1}$和$h_{\tau}$,其在位置编码表示是完全相同的,都是$U_{1:L}$,这样就会造成motivation中所述的无法区分在不同segments中相对位置相同的tokens.
3. Transformer-XL中的相对位置编码
transformer-xl中没有采用vanilla transformer中的将位置编码静态地与embedding结合的方式;而是沿用了shaw et al.2018的相对位置编码中通过将位置信息注入到求Attention score的过程中,即将相对位置信息编码入hidden state中。
为什么要这么做呢?paper中给出的解释是:
1) 位置编码在概念上讲,是为模型提供了时间线索或者说是关于如何收集信息的"bias"。出于同样的目的,除了可以在初始的embedding中加入这样的统计上的bias, 也可以在计算每层的Attention score时加入同样的信息。
2) 以相对而非绝对的方式定义时间偏差更为直观和通用。比如对于一个query vector $q_{\tau,i}$ 与 key vectors $k_{\tau, \leq i}$做attention时,这个query 并不需要知道每一个key vector在序列中的绝对的位置来决定segment的时序。它只需要知道每一对$k_{\tau,j}$ 和其本身$q_{\tau,i}$的相对距离(比如,i - j)就足够。
因此,在实际中可以创建一个相对位置编码的encodings矩阵 $R \in \mathbb{R} ^ {L_{max} \times d}$,其中第i行 $R_i$表示两个pos(比如位置pos_q, pos_k)之间的相对距离为i. (可以参考我在参考链接3中的介绍,以下图示便是一个简单的说明例子.
但是图示中的i表示query的位置pos, 与$R_i$ 中的i不同。如果以该图示为例,当pos_q = i, pos_k = i - 4时, 相对位置为 0, 二者的相对位置编码是 $R_0$。
--------------------------------------------------------------------------------------------------
Transformer-XL的相对位置编码方式是对Shaw et al.,2018 和 Huang et al.2018提出模型的改进。它由采用绝对编码计算Attention score的表达式出发,进行了改进3项改变。
若采用绝对位置编码,hidden state的表达式为:
,
那么对应的query,key的attention score表达式为:
(应用乘法分配率, query的embedding 分别与 key的embedding, positional encoding相乘相加;之后 query的positional encoding分别与 key的embedding, positional encoding相乘相加)
(其中i是query的位置index,j是key的位置index) (WE, WU是对embedding进行linear projection的表示,细节内容可以参看attention is all you need 中对multi-head attention的介绍)
,
Transformer-XL 对上式进行了改进:
改进1) $Uj \rightarrow R_{i - j}$.
首先将 $A_{i, j} ^ {abs}$ 中的key vector的绝对位置编码 $U_j$ 替换为了相对位置编码 $R_{i - j}$ 其中 $R$是一个没有需要学习的参数的sinusoid encoding matrix,如同Vaswani et al., 2017提出的一样。
该改进既可以避免不同segments之间由于tokens在各自segment的index相同而产生的时序冲突的问题。
改进2) $(c) : U_i^{T} W_q ^ {T} \rightarrow {\color{red} u} \in \mathbb{R}^d$;$(d) : U_i^{T} W_q ^ {T} \rightarrow {\color{red} v} \in \mathbb{R}^d$
在改进1中将key的绝对位置编码转换为相对位置编码,在改进2中则对query的绝对位置编码进行了替换。因为无论query在序列中的绝对位置如何,其相对于自身的相对位置都是一样的。这说明attention bias的计算与query在序列中的绝对位置无关,应当保持不变. 所以这里将$A_{i, j} ^ {abs}$ 中的c,d项中的$U_i^{T} W_q ^ {T}$分别用一个可学习参数$u \in \mathbb{R}^d$,$v \in \mathbb{R}^d$替换。
改进3) $W_{k} \rightarrow W_{k, E}$, $W_{k, R}$
在vanilla transformer模型中,对query, key分别进行线性映射时,query 对应$W_q$矩阵,key对应$W_k$矩阵,由于input 是 embedding 与 positional encoding的相加,也就相当于
$query_{embedding} W_q + query_{pos encoding} W_q$得到query的线性映射后的表征;
$key_{embedding} W_q + key_{pos encoding} W_q$ 得到key的线性映射后的表征。
可以看出,在vanilla transformer中对于embedding和positional encoding都是采用的同样的线性变换。
在改进3中,则将key的embedding和positional encoding 分别采用了不同的线性变换。其中$W_{k,E}$对应于key的embedding线性映射矩阵,$W_{k,R}$对应与key的positional encoding的线性映射矩阵。
在这样的参数化定义后,每一项都有了一个直观上的表征含义,(a)表示基于内容content的表征,(b)表示基于content的位置偏置,(c)表示全局的content的偏置,(d)表示全局的位置偏置。
与shaw的RPR的对比
shaw的RPR可以参考我在参考链接3中的介绍。这里给出论文中的表达式:其中$a_{i,j}$是query i, key j的相对位置编码矩阵$A$中的对应编码。
attention score: (在key的表征中加入相对位置信息)
softmax计算权值系数:
attention score * (value + 的output:(在value的表征中加入相对位置信息)
1) 对于$e_{ij}$可以用乘法分配率拆解来看,那么其相当于transforerm-xl中的(a)(b)两项。也就是在shaw的模型中未考虑加入(c)(d)项的全局内容偏置和全局位置偏置。
2) 还是拆解$e_{ij}$来看,涉及到一项为$x_iW^Q(a_{ij}^K)^T$,是直接用 query的线性映射后的表征 与 相对位置编码相乘;而在transformer-xl中,则是与query的线性映射后的表征 与 相对位置编码也进行线性映射后的表征 相乘。
优势:
paper中指出,shaw et al用单一的相对位置编码矩阵 与 transformer-xl中的$W_kR$相比,丢失掉了在原始的 sinusoid positional encoding (Vaswani et al., 2017)中的归纳偏置。而XL中的这种表征方式则可以更好地利用sinusoid 的inductive bias。
----------------------------为什么XL中的这种表征方式则可以更好地利用sinusoid 的inductive bias?--------------------------------------------------------------------
有几个问题:原始的 sinusoid positional encoding (Vaswani et al., 2017)中的归纳偏置是什么呢?为什么shaw et al 把它丢失了呢?为什么transformer-xl可以适用呢?
这里需要搞清楚:
1. 为什么在vanilla transformer中使用sinusoid?
2. shaw et al.2018中的相对位置编码Tensor是什么?
3. transformer-xl的相对位置编码矩阵是什么?
对于1,sinusoid函数具有并不受限于序列长度仍可以较好表示位置信息的特点。
We chose the sinusoidal version because it may allow the model to extrapolate to sequence lengths longer than the ones encountered during training. ~Attention is all you need.
为什么不用学得参数而采用sinusoid函数呢?sinusoidal函数并不受限于序列长度,其可以在遇到训练集中未出现过的序列长度时仍能很好的“extrapolate.” (外推),这体现了其具有一些inductive bias。
对于2,shaw et al.2018中的相对位置编码Tensor是两个需要参数学习的tensor.
相对位置编码矩阵是设定长度为 2K + 1的(K是窗口大小) ,维度为$d_a$的2个tensor(分别对应与key的RPR和value的RPR),其第i行表示相对距离为i的query,key(或是query, value)的相对位置编码。这两个tensor的参数都是需要训练学习的。那么显然其是受限于最大长度的。在RPR中规定了截断的窗口大小,在遇到超出窗口大小的情况时,由于直接被截断而可能丢失信息。
对于3,transformer-xl的相对位置编码矩阵是一个sinusoid矩阵,不需要参数学习。
在transformer-xl中虽然也是引入了相对位置编码矩阵,但是这个矩阵不同于shaw et al.2018。该矩阵$R_{i,j}$是一个sinusoid encoding 的矩阵(sinusoid 是借鉴的vanilla transformer中的),不涉及参数的学习。
具体实现可以参看代码,这里展示了pytorch版本的位置编码的代码:
class PositionalEmbedding(nn.Module):
def __init__(self, demb):
super(PositionalEmbedding, self).__init__() self.demb = demb inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.register_buffer('inv_freq', inv_freq) def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) if bsz is not None:
return pos_emb[:,None,:].expand(-1, bsz, -1)
else:
return pos_emb[:,None,:]
其中$demb$是embedding的维度。
sinusoid的shape:[batch_size, seq_length × (d_emb / 2)]
sin,cos concat之后,pos_emb的shape:[batch_size, seq_length × d_emb]
pos_emb[:,None,:]之后的shape:[batch_size, 1, seq_length × d_emb]
那么综合起来看,transformer-xl的模型的hidden states表达式为:
4. 高效计算方法
在该表达式中,在计算$W_{k,R}R_{i-j}$时,需要对每一对(i,j)进行计算,时间复杂度是$O(n^2)$。paper中提出了高效的计算方法,使其降为$O(n).$
核心算法:发现(b)项组成的矩阵的行列之间的关系,构建一个矩阵,将其按行左移,恰好是(b)项矩阵$B$,而所构建的矩阵只需要$O(n)$时间。
由于相对距离(i-j)的变化范围是[0, M + L - 1] (其中M是memory的长度,L是当前segment的长度)
那么令:
那么将(b)项应用与所有的(i,j)可得一个$L \times (M + L)$的矩阵 $B$: (其中q是对E经过$W_q$映射变换后的表示)
看这些带红线的部分,是不是只有q的下标不一样!
如果我们定义$\widetilde{B}$:
对比$B$与$\widetilde{B}$发现,将$\widetilde{B}$的第i行左移 $L - 1 - i$个单位即为$B$。而$\widetilde{B}$的计算仅涉及到两个矩阵的相乘,因此$B$的计算也仅需要求$qQ^T$之后按行左移即可得到,时间复杂度降为$O(n)$!
同理,可以求(d)项的矩阵D。
这样将B,D原本需要$O(n^2)$的复杂度,降为了$O(n)$.
5. 总结
Transformer-XL针对其需要对segment中相对位置的token加入位置信息的特点,将vanilla transformer中的绝对位置编码方式,改进为相对位置编码。改进中涉及到位置编码矩阵的替换、query全局向量替换、以及为key的相对位置编码和embedding分别采用了不同的线性映射矩阵W。
transformer-xl与shaw et al.2018的相对编码方式亦有区别。1. shaw et al.2018的相对编码矩阵是一个需要学习参数的tensor,受限于相对距离的窗口长度设置;而transformer-xl的相对编码矩阵是一个无需参数学习的使用sinusoid表示的矩阵,可以更好的generalize到训练集中未出现长度的长序列中;2. 相比与shaw et al.2018,transformer-xl的attention score中引入了基于content的bias,和基于位置的bias。
另外在计算优化上,transformer-xl提出了一种高效计算(b)(d)矩阵运算的方法。通过构造可以在$O(n)$时间内计算的新矩阵,并将其项左移构建出目标矩阵B,D的计算方式,将时间复杂度由$O(n^2)$降为$O(n)$。
参考:
1. Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context: https://arxiv.org/pdf/1901.02860.pdf
2. Self-Attention with Relative Position Representations (shaw et al.2018): https://arxiv.org/pdf/1803.02155.pdf
3. [NLP] 相对位置编码(一) Relative Position Representatitons (RPR) - Transformer https://www.cnblogs.com/shiyublog/p/11185625.html
[支付宝] 感谢您的捐赠!
That's been one of my mantras - focus and simplicity. Simple can be harder than complex: you have to work hard to get your thinking clean to make it simple. But it's worth it in the end beacuse once you get there, you can move mountains. ~ Steve Jobs
[NLP] 相对位置编码(二) Relative Positional Encodings - Transformer-XL的更多相关文章
- [NLP] 相对位置编码(一) Relative Position Representatitons (RPR) - Transformer
对于Transformer模型的positional encoding,最初在Attention is all you need的文章中提出的是进行绝对位置编码,之后Shaw在2018年的文章中提出了 ...
- NLP+词法系列(二)︱中文分词技术简述、深度学习分词实践(CIPS2016、超多案例)
摘录自:CIPS2016 中文信息处理报告<第一章 词法和句法分析研究进展.现状及趋势>P4 CIPS2016 中文信息处理报告下载链接:http://cips-upload.bj.bce ...
- 第五课第四周实验一:Embedding_plus_Positional_encoding 嵌入向量加入位置编码
目录 变压器预处理 包 1 - 位置编码 1.1 - 位置编码可视化 1.2 - 比较位置编码 1.2.1 - 相关性 1.2.2 - 欧几里得距离 2 - 语义嵌入 2.1 - 加载预训练嵌入 2. ...
- ICCV2021 | Vision Transformer中相对位置编码的反思与改进
前言 在计算机视觉中,相对位置编码的有效性还没有得到很好的研究,甚至仍然存在争议,本文分析了相对位置编码中的几个关键因素,提出了一种新的针对2D图像的相对位置编码方法,称为图像RPE(IRPE). ...
- 中文NER的那些事儿5. Transformer相对位置编码&TENER代码实现
这一章我们主要关注transformer在序列标注任务上的应用,作为2017年后最热的模型结构之一,在序列标注任务上原生transformer的表现并不尽如人意,效果比bilstm还要差不少,这背后有 ...
- C语言基础练习——最大值及其位置(二维数组)
C语言基础练习——最大值及其位置(二维数组) 时间限制: 1 Sec 内存限制: 10 MB 题目描述 有一个n×m的矩阵,要求编程序求出: 每行元素的最大值,以及其所在的行号和列号.求出所有元素的 ...
- (Stanford CS224d) Deep Learning and NLP课程笔记(二):word2vec
本节课将开始学习Deep NLP的基础--词向量模型. 背景 word vector是一种在计算机中表达word meaning的方式.在Webster词典中,关于meaning有三种定义: the ...
- 利用Tensorflow进行自然语言处理(NLP)系列之二高级Word2Vec
本篇也同步笔者另一博客上(https://blog.csdn.net/qq_37608890/article/details/81530542) 一.概述 在上一篇中,我们介绍了Word2Vec即词向 ...
- Android应用中使用百度地图API定位自己的位置(二)
官方文档:http://developer.baidu.com/map/sdkandev-6.htm#.E7.AE.80.E4.BB.8B3 百度地图SDK为开发人员们提供了例如以下类型的地图覆盖物: ...
随机推荐
- 32个Python爬虫项目让你一次吃到撑
整理了32个Python爬虫项目.整理的原因是,爬虫入门简单快速,也非常适合新入门的小伙伴培养信心.所有链接指向GitHub,祝大家玩的愉快~O(∩_∩)O WechatSogou [1]- 微信公众 ...
- Spring Cloud微服务简介
概述 Spring Cloud给开发者提供一套按照一定套路快速开发分布式工具.它为微服务架构中涉及**配置管理,服务治理,断路器,智能路由,微代理,控制总线,全局锁,分布式会话和集群状态等操作提供了一 ...
- Python连载14-random模块&函数式编程
一.random模块 1.函数:random() (1)用法:获取0~1之间的随即小数 (2)格式:random.random() (3)返回值:随机0~1之间的小数 2.函数:choice() ( ...
- Android native进程间通信实例-binder篇之——HAL层访问JAVA层的服务
有一天在群里聊天的时候,有人提出一个问题,怎样才能做到HAL层访问JAVA层的接口?刚好我不会,所以做了一点研究. 之前的文章末尾部分说过了service call 可以用来调试系统的binder服务 ...
- Java中常用的url签名防篡改方法
实现方式:Md5(url+key) 的方式进行的. 1.key可以是任意的字符串,然后“客户端”和“服务器端”各自保留一份,千万不能外泄. 2.请求的URL 例如: name=lxl&age ...
- List集合总结,对比分析ArrayList,Vector,LinkedList
前面已经写了三篇关于Java集合的文章,包括: Java集合 ArrayList原理及使用 再说Java集合,subList之于ArrayList Java集合 LinkedList的原理及使用 关于 ...
- 【需要重新维护】Redis笔记20170811视频
很多内容都是抄的,个人记录 1.windows下初见 安装 进入目录 修改配置文件(暂时使用默认,未配置环境变量) 目录下:redis-server.exe启动服务 新建命令提示符,目录下,redis ...
- Python基础(九) 常用模块汇总
3.8 json模块重点 json模块是将满足条件的数据结构转化成特殊的字符串,并且也可以反序列化还原回去. 不同语言都遵循的一种数据转化格式,即不同语言都使用的特殊字符串.(比如Python的一个列 ...
- Atlassian In Action - (Atlassian成长之路)
Atlassian是我工作过程中,使用过的最满意的研发团队管理套装.使用的主要软件包括Jira Software,Confluence,Fisheye/Crucible.理论上还可以再加上Bitbuc ...
- python多线程爬取图片实例
今天试着把前面那个爬取图片的爬虫改成了多线程爬取,虽然最后可以爬取存储图片了,但仍存在一些问题.网址还是那个网址https://www.quanjing.com/category/1286521/1. ...