pytorch笔记:09)Attention机制
刚从图像处理的hole中攀爬出来,刚走一步竟掉到了另一个hole(fire in the hole*▽*)
1.RNN中的attention
pytorch官方教程:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
首先,RNN的输入大小都是(1,1,hidden_size),即batch=1,seq_len=1,hidden_size=embed_size,相对于传统的encoder-decoder模型,attention机制仅在decoder处有所不同。下面具体看看:
1>保存了rnn每个词向量对应隐藏层的输出状态(encoder_outputs),用于decoder的attention机制
#train代码部分
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(
input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
1
2
3
4
5
2>AttnDecoderRNN的forward
1.输入的input经过embed
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
1
2
2.获取关于输入的attention权重,这里的Q=decoder_rnn的input,K=decoder_rnn的隐藏元
2.1求Q和K相似度的方法有很多,这里让全连接层自己来学习,把embedded和hidden连接在一起经过fc层(部分修改了下)
similarity=self.attn(torch.cat((embedded[0], hidden[0]), 1))
1
2.2 经过softmax获得归一化的权重
attn_weights = F.softmax(similarity, dim=1)
1
3.权重应用于encoder输出的所有词对应的词向量上(对应相乘即可)->获得attention结果
attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))
1
4.把attention结果和decoder的输入cat在一起,使用1个全连接层来融合二者,最终生成带注意力机制的词向量
output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
1
2
5.根据decoder的上一个输出单词来预测下一个单词,这里多插一句,decoder的首个输入为起始标志符’sos’,其根据encode最后的隐藏元来预测第一个单词,后面依次类推。
output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
return output, hidden, attn_weights
1
2
3
4
2.transformer中的attention
“Attention is All You Need”(霸气标题),pytorch代码推荐2篇:
哈佛大学NLP研究组:http://nlp.seas.harvard.edu/2018/04/03/attention.html
台湾小哥的代码(较通俗):https://github.com/jadore801120/attention-is-all-you-need-pytorch:
下面以soft_attention为例(*input和output的attention,仅和self_attention做下区分,第1篇代码标记src_attn,第2篇代码标记dec_enc_attn),soft_attention的目标:给定序列Q(query,长度记为lq,维度dk),键序列K(key,长度记为lk,维度dk),值序列V(value,长度记为lv,维度dv),计算Q和K的相似度权重,最后再乘上V。
下面直接贴上attention-is-all-you-need-pytorch中MultiHeadAttention代码
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
#这里把batch和分块数放在一起,便于使用bmm
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
和RNN中的attention的不同,这里的batch_size和seq_len均不为1,其把序列视为一个整体,求Q和V的相似度可使用点乘(V可以视为上面提及的encoder_outputs),获得的是一个相似度矩阵,比如Q是一个长度为10的序列,K是一个长度为16的序列,其相似度矩阵就是一个10*16的矩阵,再如矩阵第一行表示Q的第一个单词和K序列所有单词的相似度。
similarity:=(lq,dk)∗(dk,lk)=(lq,lk) similarity:=(lq,dk)*(dk,lk)=(lq,lk)
similarity:=(lq,dk)∗(dk,lk)=(lq,lk)
然后,生成带注意力机制的词向量(通常K和V取相同的值,因而有lv=lk),另外上面整合attn_applied和input使用的是cat操作,而这里使用的是残差(类似于unet和resnet),最后使用PositionwiseFeedForward(2个fc层)来融合attn_applied和input,最终生成带注意力机制的词向量。
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv) attention\_applied=(lq,lk)*(lv,dv)=(lq,dv)
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv)
细节部分
在数据预处理部分,对序列s都进行了首尾标记,比如s=’’+ s + ‘’,刚看transform(之前跳过了seq2seq),对下面的代码甚是不解
decoder_input=target_seq[:, :-1] #这里不是去掉终止标记<eos>,去掉的可能是padding_0,只为兼容target_ground_y的序列长度?
encoder_input=input_seq[:, 1:] #encoder的输入序列去掉了起始标记<sos>
target_ground_y= target_seqtrg[:, 1:] #用于计算模型loss的target,去掉了起始标记<sos>
1
2
3
其实在pytorch官方教程中说的比较清楚,看下图
encoder的输入序列和ground_true只需要一个终止符即可,而decoder的输入序列开始必须指定一个起始符,让其根据context预测输出序列的第一个单词,后面根据前一个单词再预测下一个单词,依次类推直到当前预测的单词为终止标记’eos’,才计算loss.
---------------------
作者:PJ-Javis
来源:CSDN
原文:https://blog.csdn.net/jiangpeng59/article/details/84859640
版权声明:本文为博主原创文章,转载请附上博文链接!
pytorch笔记:09)Attention机制的更多相关文章
- Multimodal —— 看图说话(Image Caption)任务的论文笔记(三)引入视觉哨兵的自适应attention机制
在此前的两篇博客中所介绍的两个论文,分别介绍了encoder-decoder框架以及引入attention之后在Image Caption任务上的应用. 这篇博客所介绍的文章所考虑的是生成captio ...
- Multimodal —— 看图说话(Image Caption)任务的论文笔记(二)引入attention机制
在上一篇博客中介绍的论文"Show and tell"所提出的NIC模型采用的是最"简单"的encoder-decoder框架,模型上没有什么新花样,使用CNN ...
- 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...
- 【学习笔记】注意力机制(Attention)
前言 这一章看啥视频都不好使,啃书就完事儿了,当然了我也没有感觉自己学的特别扎实,不过好歹是有一定的了解了 注意力机制 由于之前的卷积之类的神经网络,选取卷积中最大的那个数,实际上这种行为是没有目的的 ...
- DL4NLP —— seq2seq+attention机制的应用:文档自动摘要(Automatic Text Summarization)
两周以前读了些文档自动摘要的论文,并针对其中两篇( [2] 和 [3] )做了presentation.下面把相关内容简单整理一下. 文本自动摘要(Automatic Text Summarizati ...
- 论文笔记:Attention Is All You Need
Attention Is All You Need 2018-04-17 10:35:25 Paper:http://papers.nips.cc/paper/7181-attention-is-a ...
- [NLP/Attention]关于attention机制在nlp中的应用总结
原文链接: https://blog.csdn.net/qq_41058526/article/details/80578932 attention 总结 参考:注意力机制(Attention Mec ...
- Java:并发笔记-09
Java:并发笔记-09 说明:这是看了 bilibili 上 黑马程序员 的课程 java并发编程 后做的笔记 7. 共享模型之工具-2 原理:AQS 原理 对于 AQS 的原理这部分内容,没很好的 ...
- Mongodb源代码阅读笔记:Journal机制
Mongodb源代码阅读笔记:Journal机制 Mongodb源代码阅读笔记:Journal机制 涉及的文件 一些说明 PREPLOGBUFFER WRITETOJOURNAL WRITETODAT ...
随机推荐
- MySQL-业务优化——说的就是变
前言 通过上次发布的业务优化不是一步到位的有不少网友问我许多关于业务优化和Web方面的问题.在这里表示感谢和支持.在期间有些回答不到位的还请谅解,并且个人经验有限. 百牛信息技术bainiu.ltd整 ...
- Nvidia GPU 算力查询
GPU Compute Capability Tesla K80 3.7 Tesla K40 3.5 Tesla K20 3.5 Tesla C2075 2.0 Tesla C2050/C2070 2 ...
- MVC视图中下拉框的使用
一.一般变量或对象的绑定 首先要在controller 中将选项设置成 selecList对象,并赋值给viewBag动态对象. public ActionResult Index(string mo ...
- JavaScript 算法与数据结构(转载)
JavaScript 算法与数据结构 https://github.com/trekhleb/javascript-algorithms/blob/master/README.zh-CN.md
- bzoj 2151: 种树【贪心+堆】
和数据备份差不多 设二元组(i,a[i]),开一个大根堆把二元组塞进去,以len排序,每次取出一个二元组 因为单纯的贪心是不行的,所以设计一个"反悔"操作. 记录二元组的前驱pr后 ...
- vultr 购买vps
基本安装转自:https://github.com/uxh/shadowsocks_bash/wiki/Vultr%E4%BD%BF%E7%94%A8%E6%95%99%E7%A8%8B 连接 Vul ...
- _bzoj2818 Gcd【线性筛法 欧拉函数】
传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=2818 若gcd(x, y) = 1,则gcd(x * n, y * n) = n.那么,当y ...
- 转 11g RAC R2 体系结构---Grid
基于agent的管理方式 从oracle 11.2开始出现了多用户的概念,oracle开始使用一组多线程的daemon来同时支持多个用户的使用.管理资源,这些daemon叫做Agent.这些Agent ...
- 456 132 Pattern 132模式
给定一个整数序列:a1, a2, ..., an,一个132模式的子序列 ai, aj, ak 被定义为:当 i < j < k 时,ai < ak < aj.设计一个算法,当 ...
- 转】用Mahout构建职位推荐引擎
原博文出自于: http://blog.fens.me/hadoop-mahout-recommend-job/ 感谢! 用Mahout构建职位推荐引擎 Hadoop家族系列文章,主要介绍Hadoop ...