【译】在Transformer中加入相对位置信息
本文翻译自How Self-Attention with Relative Position Representations works, 介绍 Google的研究成果。
引言
本文基于Shaw 等人发表的论文 《Self-Attention with Relative Position Representations》 展开。论文介绍了一种在一个Transformer内部编码输入序列的位置信息的方法。特别的是,论文改进了Tranformer的自注意力机制,让其能够更有效地将序列中的词之间的相对距离考虑进来。
本文旨在用易于理解的语言解释论文中的要点。读懂本文的前提是对 Recurrent Neural Networks(RNNs) 和Transformers 中的多头注意力机制(multi-head self-attention mechanism)有基本的了解。
动机
利用隐状态hidden state,RNN能够让模型隐式地编码序列的顺序信息。例如,下图展示了RNN输出输入序列“I think therefore I am” 中每一个词的向量表示。

第二个“I”的输出不同于第一个“I”的输出,因为他们隐状态的输入是不一样的。对第二个“I”而言,隐状态经过了 “I think therefore”三个词,而第一个“I” 的隐状态仅是一个初始值。因此,RNN的隐状态保证了在不同位置上的相同的词会有不同的输出向量表示。
相比之下,Transformer的自注意力层(不带位置表示)对不同位置出现的相同词给出的是同样的输出向量表示。例如:

上图展示了输入序列“I think therefore I am”送入Transformer的过程。 为了方便阅读,仅仅画出两个“I”的输出。注意,尽管两个“I”在不同输入序列的不同位置上,他们对应的输出向量表示还是相同的。
解决方案
概览
作者提出的方法是,在Transformer中加入一组可训练的嵌入表示,从而让输出带有一定的顺序信息。这一嵌入表示在计算第i个词和第j个词之间的注意力权重和注意力值的时候会用到。他们代表了第i个词和第j个词之间的距离(间隔多少个词),因此将这种方法称为相对位置表示(RPR)。
例如,一个句子由五个词,一共会有9个嵌入表示需要学习(一个是当前词的嵌入,有4个是上文4个词的嵌入,4个是下文4个词的嵌入。译者注:k=4)。9个嵌入如下所示:

下图清晰地展示了如何使用这些嵌入:

上图描绘了第一个“I”的输出表示的计算过程。箭头旁的数字表示在计算注意力的时候使用的是哪一种相对位置表示。例如,当Transformer正在计算“I”和“therefore”之间的注意力时,它会利用包含在第6个RPR中的信息,因为“therefore” 是第一个“I”右边的第2个词。(译者注:因为k设置为4,因此词i到词i的距离对应index4,词i到词i+1的距离对应index5,词i到词i+2的距离对应index6,以此类推)
下图描绘了第二个“I”的输出表示的计算过程。

但是,每个词的RPR又是不一样的。例如,第3个RPR是用来计算 “I”和“therefore” 之间的注意力的,因为“therefore”是第二个“I”的左边的第一个词。这就是RPR帮助Transformer编码输入序列的顺序信息。
注释
下面的符号注释在本文后面的阐述中会用到。

注意,这其中共有两组RPR嵌入需要学习:一个用于计算词i的输出表示zᵢ,另一个用于计算词i到词j的权重系数eᵢⱼ。不同于投影矩阵,这些嵌入在注意力头间是共享的。
另一个值得注意的关键点是,需要考虑的词间距离的最大值被限制在一个常数k。这意味着,需要学习的RPR嵌入的数量是2k+1(上文k个词,下文k个词以及当前词)。向右间隔词i超过k个词的词对应第2k个RPR, 向左间隔词i超过k个词的词对应第0个RPR。例如,一个有10个词的输入序列,k设为3,那么RPR嵌入的lookup表如下:

按照这种设计,行i对应第i个词,列j代表第j个词。索引号3对应第i个词,索引号6对应第i个词右边第3个以及更右的词,索引号0对应第i个词左边第3个以及更左的词。第1个词(第1行)的嵌入表示的通过查表可得。注意,从第i个词右边第3个词起的所有词的索引号都是6。这意味着即使输入序列的第一个词和最后一个词之间的距离是9,最后一个词使用的RPR嵌入也与右边第3个词的RPR嵌入相同。
这么设计有两个原因:
- 作者假定在一定距离之外,再精确的相对位置信息也是没有用的。
- 限制住最长距离能够提升模型对未在训练阶段出现过的长度的序列的泛化能力。
实现
下面的等式展示了在没有使用RPR嵌入的情况下,计算 zᵢ 的过程:

引入RPR嵌入后的式子 (1)变成了:

式子 (2)变成了 :

总而言之,式子3是当要计算词i的输出表示时,我们对相对词j的value向量的权重的计算进行了改进,方法就是将相对于词j的value向量加上词i和词j之间的RPR嵌入。同理,式子4告诉我们,如何改进词i和词j之间的缩放的点积操作,就是通过将相对于词j的key向量加上词i和词j之间的RPR嵌入。根据作者的描述,使用加法作为一种将RPR嵌入整合进来的方法让算法实现更高效,本文后面会继续介绍。
高效实现
Transformer的输入是一个大小为 (batch_size, seq_length, embedding_dim)的张量。在不带RPR嵌入的情况下,Transformer能够利用batch_size * h 并行地进行矩阵乘法来计算 eᵢⱼ (式子2) 。每一次矩阵乘法都会计算给定输入序列和注意力头中所有的元素的eᵢⱼ 。这个过程使用下面的表达式实现的:

X是给定输入序列中所有元素按行拼接起来的矩阵。
为了在加入了RPR嵌入之后也能有相近的计算效率(时间上和空间上),我们首先使用了矩阵乘法的性质将式子(4)重写为:

分子的左半部分和式子 (2)相同,因此在矩阵乘法中能够高效运算。右半部分就有点技巧性了。这部分代码实现定义在函数 relative_attention_inner 中,因此我会较简单地把大体逻辑介绍一下。
- 分子左半部分的大小为 (batch_size, h, seq_length, seq_length)。这个张量的行i列j上的元素代表了词i的query向量和词j的key向量的点积的结果 。因此,我们的目标是产生另一个和这个张量大小相同的张量,而这个张量的各个元素应该是词i的query向量和词i与词j之间的RPR嵌入的点积的结果(译者注:也就是分子右半部分)。
- 首先,我们使用查表的形式为一个给定的输入序列生成RPR嵌入张量A,A的形状是(seq_length, seq_length, dₐ)。然后,我们对A进行转置,使它的形状变成 (seq_length, dₐ , seq_length) ,写成 Aᵀ。
- 接下来,我们计算输入序列所有元素的query向量,得到一个 (batch_size, h, seq_length, dz)形状的张量。然后对其进行转置,形状变为 (seq_length, batch_size, h, dz) ,然后变形为 (seq_length, batch_size * h, dz)的张量。这个张量现在就能与 Aᵀ相乘了。这个乘法可以视为矩阵 (batch_size * h, dz) 和矩阵 (dₐ, seq_length)的乘法。基本上就是计算每个位置的query向量和对应的RPR嵌入的点积。
- 上面的乘法得到一个形状为 (seq_length, batch_size * h, seq_length)的张量。我们只需要将其变形为(seq_length, batch_size, h, seq_length)的形状,然后再转置得到形状为 (batch_size, h, seq_length, seq_length) 的张量,这样我们就能将它和分子左半部分进行相加了。
同样的逻辑也用在式子 (3)的计算中。
结果
作者在与Vaswani 等人发表的论文《Attention is All You Need》 中相同的机器翻译任务上评价他的改进方法的对翻译效果的影响。尽管每秒钟的训练步数下降了7个百分点,其模型在英译德任务上的BLEU还是提高了1.3,在英译法上提高了0.5。
结论
在本文中,笔者解释了为什么Transformer中的自注意力机制无法编码输入序列的位置信息,以及Shaw 等人相对位置表示嵌入(RPR)如何解决这一问题。笔者希望本文能帮助你更好的理解Shaw的文章。
参考文献
Self-Attention with Relative Position Representations; Shaw et al. 2018.
Attention Is All You Need; Vaswani et al., 2017.
Transformer: A Novel Neural Network Architecture for Language Understanding; Uszkoreit. 2017
The Unreasonable Effectiveness of Recurrent Neural Networks; Karpathy. 2015.
【译】在Transformer中加入相对位置信息的更多相关文章
- android黑科技系列——微信定位聊天记录中照片的位置信息插件开发详解
一.前言 最近关于微信中,朋友之间发送原图就可能暴露你的位置信息,其实这个问题不在于微信,微信是为了更好的体验效果,才有发送原图功能,而对于拍照,发送普通图片微信后台都会过滤图片的exif信息,这样就 ...
- windowsphone中获取手机位置信息
首先在界面中加入一个textblock控件以显示信息 using System; using System.Collections.Generic; using System.IO; using Sy ...
- 【转载】C#通过IndexOf方法获取某一列在DataTable中的索引位置
在C#中的Datatable数据变量的操作过程中,有时候需要知道某一个列名在DataTable中的索引位置信息,此时可以通过DataTable变量的Columns属性来获取到所有的列信息,然后通过Co ...
- 获取用户当前位置信息的两种方法——H5、微信
在之前的 调用百度地图API的总结 中获取当前位置信息我用的是 H5 ,其实微信也提供了获取用户地理位置的方法,现将这两种方法都贴出来,看情况选择使用. 一.H5 获取当前地理位置得到经纬度 // H ...
- ICCV2021 | Vision Transformer中相对位置编码的反思与改进
前言 在计算机视觉中,相对位置编码的有效性还没有得到很好的研究,甚至仍然存在争议,本文分析了相对位置编码中的几个关键因素,提出了一种新的针对2D图像的相对位置编码方法,称为图像RPE(IRPE). ...
- 使用NPOI从Excel中提取图片及图片位置信息
问题背景: 话说,在ExcelReport的开发过程中,有一个比较棘手的问题:怎么复制图片呢? 当然,解决这个问题的第一步是:能使用NPOI提取到图片及图片的位置信息.到这里,一切想法都很顺利.但NP ...
- JavaScript基础(3)-JS中的面向对象、定时器、BOM、位置信息
一.创建对象的几种常用方式. 1.使用Object或对象字面量创建对象: a.使用Object()内置的构造函数来创建对象,例如: var student = new Object(); // 创建一 ...
- jquery获取元素在文档中的位置信息以及滚动条位置(转)
jquery获取元素在文档中的位置信息以及滚动条位置 http://blog.csdn.net/qq_34095777/article/details/78750886 原文链接 原创 201 ...
- C#实现如何判断一个数组中是否有重复的元素 返回一个数组升序排列后的位置信息--C#程序举例 求生欲很强的数据库 别跟我谈EF抵抗并发,敢问你到底会不会用EntityFramework
C#实现如何判断一个数组中是否有重复的元素 如何判断一个数组中是否有重复的元素 实现判断数组中是否包含有重复的元素方法 这里用C#代码给出实例 方法一:可以新建一个hashtable利用hasht ...
随机推荐
- Django学习笔记之视图高级-HTTP请求与响应
Django限制请求method 常用的请求method GET请求 GET请求一般用来向服务器索取数据,但不会向服务器提交数据,不会对服务器的状态进行更改.比如向服务器获取某篇文章的详情. POST ...
- 关闭linux服务器防火墙
--全部关闭 systemctl stop firewalld.service #停止firewallsystemctl disable firewalld.service #禁止firewal ...
- nginx配置备份
server { listen 80; server_name localhost; set $expires_duration "30d"; if ($uri ~* \.html ...
- 【Jmeter】Address already in use : connect &&Permission denied: connect 解决方案
Address already in use : connect 该问题的原因为: Windows 提供给 TCP/IP链接的端口为 1024-5000,并且要四分钟来循环回收他们.就导致我们在短 ...
- [例子]Ubuntu虚拟机设置固定IP上网
宿主机器 win7 linux Ubuntu 14.04 LTS 参考: Linux系列:Ubuntu虚拟机设置固定IP上网(配置IP.网关.DNS.防止resolv.c ...
- Javascript神器之webstorm
推荐个编辑器主题下载的一个网站. Color Themes 网址:http://color-themes.com [点这里直接跳转] 但是,只支持几个编辑器. 各种颜色搭配的主题,随你选择!我个 ...
- enable-ssh-key-logon-disable-password-password-less-logon-centos/
cat ~/.ssh/id_rsa.pub | ssh root@destination_server_address "cat >> ~/.ssh/authorized_key ...
- mysql数据库目录my.ini的内容
[mysql] # 设置mysql客户端默认字符集 default-character-set=utf8 [mysqld] #设置3306端口 port = 3306 # 设置mysql的 ...
- Webpack 模块处理
webpack模块处理 1. ES6 静态Import ES6的import会被转化为commonjs格式或者是AMD格式,babel默认会把ES6的模块转化为commonjs规范的. import ...
- 去中心化存储项目终极指南 | Filecoin, Storj 和 PPIO 项目异同
Filecoin,Storj 以及 PPIO 这三个存储公链的设计思路是不一样的,没有优劣之分,写这篇文章也并不是为了争论各项目的好坏对错.去中心化存储是一个长期商业赛道,不同团队在同一个赛道上往不同 ...