谷歌在文章《Attention is all you need》中提出的transformer模型。如图主要架构:同样为encoder-decoder模式,左边部分是encoder,右边部分是decoder。
TensorFlow代码:https://www.github.com/kyubyong/transformer

用 sentencepiece 进行分词。

Encoder 输入

初始输入为待翻译语句的embedding矩阵,由于句子长度不一致,需要做统一长度处理,长度取maxlength1,不够长的句子padding 0值,句尾加上 </s>

d = 512, [batchsize,maxlen1,d]

考虑到词语间的相对位置信息,还要加上语句的position
encoding,由函数形式直接求出。

PE(pos,2i) = sin(pos/10002i/d)
PE(pos,2i+1) = cos(pos/10002i/d)

Padding的值不做position encoding。 [batchsize,maxlen1,d] ,最终:

encoder input = position encoding + input embedding。
encoder input : [batchsize,maxlen1,d]

Encoder

Encoder 由N = 6个相同的layer连接组成。每个layer中有两个sublayer,分别是multihead
self-attention以及FFN。

Q = K = V = input
MultiHead(Q, K, V) = concat(head1, …, headh)Wo
headi = Attention(QW­iQ,KW­ik,VW­iV)
Attention(Q, K, V) = softmax(QKT/$$sqrt{d}$$) V



softmax前要做key_mask,把pad 0 的地方赋值为-inf,softmax后权重做query mask,赋值0。

h = 8
W­iQ, W­ik, W­iV : [d, d/h]
Q : [maxlen_q, d]
K = V : [maxlen_k, d]
Maxlen_q = maxlen_k so: Q = K = V : [maxlen1, d]
QW­kQ,KW­ik,VW­iV : [maxlen1, d/h]
headi : [maxlen1, d/h] * [d/h, maxlen1] * [maxlen1, d/h] = [maxlen1, d/h]
Wo : [d, d]
MultiHead(Q,K,V): [maxlen, d]

Softmax([maxlen_q, maxlen_k]) 在最后一个维度即 maxlen_k 上做 softmax
position-wise是因为处理的attention输出是某一个位置i的attention输出。

FFN(x) = ReLU ( xW1 + b1 ) * W2 + b2
ReLU(x) = max( 0, x )
dff = 4 * d = 2048
W1 : [d, dff]
W2 : [dff, d]

流程:

Input -> dropout ->
(
multihead self-attention -> dropout -> residual connection -> LN ->
FFN-> dropout -> residual connection -> LN ->
) * 6
-> memory [batchsize,maxlen,d]

代码中在multihead attention中对score做dropout,FFN后没有dropout,但文章说每个sublayer的output都有一个dropout。

大专栏  Transformer详解:各个特征维度分析推导"#Decoder-输入" class="headerlink" title="Decoder 输入">Decoder 输入

训练

目标句子首尾分别加上 <s> , </s>

Decoder input = Output embedding + position encoding
Decoder input : [batchsize,maxlen2,d]

预测

初始向量为<s>对应embedding,之后将前一步的输出拼接到当前的所有预测构成当前的decoder输入。

Decoder

Decoder由N = 6 个相同的layer组成,每个layer中有三个sublayer,分别是multihead self-attention, mutihead attention以及FFN。

decoder input -> dropout ->
(
Masked multihead self-attention(dec, dec, dec) = dec-> dropout ->
multihead attention(dec, memory, memory) -> dropout -> residual connection
-> LN -> FFN -> dropout -> residual connection -> LN ->
) * 6
-> dec -> linear -> softmax

Self-attention 的mask为一个和dec相同维度的上三角全为-inf的矩阵。

Linear( x ) = xW
Dec : [batchsize,maxlen2,d]
W : [d, vocabsize]

W为词汇表embedding矩阵的转置, 输入输出的词汇表embedding矩阵为W。即三个参数共享。

Linear( x ) : [batchsize,maxlen2,vocabsize]

Softmax函数:

$pleft( k|x right)=frac{exp({{z}_{k}})}{sumnolimits_{i=1}^{K}{exp ({{z}_{i}})}}$

其中zi一般叫做 logits,即未被归一化的对数概率。

损失函数

损失函数:cross entropy。用p代表predicted probability,用q代表groundtruth。即:

$cross_entropy_loss=sumlimits_{k=1}^{K}{qleft( k|xright)log (pleft( k|x right))}$

groundtruth为one-hot,即每个样本只有惟一的类别,$q(k)={{delta}_{k,y}}$,y是真实类别。

${{delta }_{k,y}}text{=}left{begin{matrix} 1,k=y \0,kne y \end{matrix} right.$

对目标句子onehot 做labelmsmooth用$tilde{q}(k|x)$代替$q(k|x)$。(为了正则化,防止过拟合)

$tilde{q}(k|x)=(1-varepsilon ){{delta }_{k,y}}+varepsilon u(k)$

可以理解为,对于$q(k)={{delta}_{k,y}}$函数分布的真实标签,将它变成以如下方式获得:首先从标注的真实标签的$delta$分布中取定,然后以一定的概率$varepsilon$,将其替换为在$u(k)$分布中的随机变量。$u(k)$为均匀分布,即$u(k)=1/K$

优化方法

Adam优化器:

学习率使用warm up learning rate:

learningrate = dmodel-0.5 * min ( step_num-0.5, step_num * warmup_steps-1.5 )
warmup_steps :4000

Transformer详解:各个特征维度分析推导的更多相关文章

  1. Android应用AsyncTask处理机制详解及源码分析

    1 背景 Android异步处理机制一直都是Android的一个核心,也是应用工程师面试的一个知识点.前面我们分析了Handler异步机制原理(不了解的可以阅读我的<Android异步消息处理机 ...

  2. Java SPI机制实战详解及源码分析

    背景介绍 提起SPI机制,可能很多人不太熟悉,它是由JDK直接提供的,全称为:Service Provider Interface.而在平时的使用过程中也很少遇到,但如果你阅读一些框架的源码时,会发现 ...

  3. Spring Boot启动命令参数详解及源码分析

    使用过Spring Boot,我们都知道通过java -jar可以快速启动Spring Boot项目.同时,也可以通过在执行jar -jar时传递参数来进行配置.本文带大家系统的了解一下Spring ...

  4. 【转载】Android应用AsyncTask处理机制详解及源码分析

    [工匠若水 http://blog.csdn.net/yanbober 转载烦请注明出处,尊重分享成果] 1 背景 Android异步处理机制一直都是Android的一个核心,也是应用工程师面试的一个 ...

  5. 线程池底层原理详解与源码分析(补充部分---ScheduledThreadPoolExecutor类分析)

    [1]前言 本篇幅是对 线程池底层原理详解与源码分析  的补充,默认你已经看完了上一篇对ThreadPoolExecutor类有了足够的了解. [2]ScheduledThreadPoolExecut ...

  6. Attention和Transformer详解

    目录 Transformer引入 Encoder 详解 输入部分 Embedding 位置嵌入 注意力机制 人类的注意力机制 Attention 计算 多头 Attention 计算 残差及其作用 B ...

  7. SpringMVC异常处理机制详解[附带源码分析]

    目录 前言 重要接口和类介绍 HandlerExceptionResolver接口 AbstractHandlerExceptionResolver抽象类 AbstractHandlerMethodE ...

  8. Linux 链接详解----静态链接实例分析

    由Linux链接详解(1)中我们简单的分析了静态库的引用解析和重定位的内容, 下面我们结合实例来看一下静态链接重定位过程. /* * a.c */ ; void add(int c); int mai ...

  9. HTTP协议详解之http请求分析

    当今web程序的开发技术真是百家争鸣,ASP.NET, PHP, JSP,Perl, AJAX 等等. 无论Web技术在未来如何发展,理解Web程序之间通信的基本协议相当重要, 因为它让我们理解了We ...

随机推荐

  1. 爬虫之xpath解析库

    xpath语法: 1.常用规则:    1.  nodename:  节点名定位    2.  //:  从当前节点选取子孙节点    3.  /:  从当前节点选取直接子节点    4.  node ...

  2. 3.redis认证

    redis认证方法 1.redis.conf requirepass PASSWORD 2.redis-cli auth PASSWORD redis清空数据库 flushdb //清空当前数据库 f ...

  3. 黑马_10 Lucene:全文检索

    10 Lucene:01.全文检索基本介绍 10 Lucene:02.创建索引库和查询索引 10 Lucene:03.中文分析器 10 Lucene:04.索引库维护CURD

  4. flask汇总

    flask框架 蓝图 随着flask程序越来越复杂,我们需要对程序进行模块化的处理,之前学习过python的模块化管理,于是针对一个简单的flask程序进行模块化处理 Blueprint概念 简单来说 ...

  5. 通过OAuth2.0 获取授权访问SF 用户数据

    站长资讯: 创建应用程序 新建应用程序   访问示例(Python+django) 环境准备: index.html 两种方式: 方式一:采用由用户授权,调用者无需知道SF的用户名与密码 方式二:直接 ...

  6. JavaSE--[转]加密和签名的区别

    转载 http://blog.csdn.net/u012467492/article/details/52034835 私钥用来签名的,公钥用来验签的.公钥加密私钥解密是秘送,私钥加密公钥解密是签名 ...

  7. 注册服务和发现服务 Eureka

    来自蚂蚁课堂: 注册服务和发现服务 1.原理如图: 注册中心负载均衡: 实践 注册中心 集群:

  8. 浙江省赛 ZOJ - 4033

    题意: 第一行给出T代表有几个测试样例, 第二行给出n代表有几个人, 第三行给出一个由0和1组成的字符串,0代表女生,1代表男生. 并且第i个人有i个宝石. 现在要把这些人分为四组,G1 G2 两组是 ...

  9. 2019深圳Android千人开发者大会【NEW·无界】

    报名地址:https://www.hdb.com/dis/mjcsegnslu 安卓巴士技术社区是中国领先的安卓开发者社区,现已聚集超过85万开发者,数年来一直致力于IT从业者的知识分享服务. 安卓巴 ...

  10. iOS中代理属性为什么要用Weak修饰?

    一.写在前面 代理设计模式,在iOS开发过程中,是一个非常常见的设计模式,可以说用的范围非常广泛,而对初学者来讲,常常对代理的属性修饰用weak存在疑惑,因此下面就解释一下其中非常简单的道理. 二.必 ...