pytorch笔记:09)Attention机制
刚从图像处理的hole中攀爬出来,刚走一步竟掉到了另一个hole(fire in the hole*▽*)
1.RNN中的attention
pytorch官方教程:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
首先,RNN的输入大小都是(1,1,hidden_size),即batch=1,seq_len=1,hidden_size=embed_size,相对于传统的encoder-decoder模型,attention机制仅在decoder处有所不同。下面具体看看:
1>保存了rnn每个词向量对应隐藏层的输出状态(encoder_outputs),用于decoder的attention机制
#train代码部分
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(
input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
1
2
3
4
5
2>AttnDecoderRNN的forward
1.输入的input经过embed
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
1
2
2.获取关于输入的attention权重,这里的Q=decoder_rnn的input,K=decoder_rnn的隐藏元
2.1求Q和K相似度的方法有很多,这里让全连接层自己来学习,把embedded和hidden连接在一起经过fc层(部分修改了下)
similarity=self.attn(torch.cat((embedded[0], hidden[0]), 1))
1
2.2 经过softmax获得归一化的权重
attn_weights = F.softmax(similarity, dim=1)
1
3.权重应用于encoder输出的所有词对应的词向量上(对应相乘即可)->获得attention结果
attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))
1
4.把attention结果和decoder的输入cat在一起,使用1个全连接层来融合二者,最终生成带注意力机制的词向量
output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
1
2
5.根据decoder的上一个输出单词来预测下一个单词,这里多插一句,decoder的首个输入为起始标志符’sos’,其根据encode最后的隐藏元来预测第一个单词,后面依次类推。
output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
return output, hidden, attn_weights
1
2
3
4
2.transformer中的attention
“Attention is All You Need”(霸气标题),pytorch代码推荐2篇:
哈佛大学NLP研究组:http://nlp.seas.harvard.edu/2018/04/03/attention.html
台湾小哥的代码(较通俗):https://github.com/jadore801120/attention-is-all-you-need-pytorch:
下面以soft_attention为例(*input和output的attention,仅和self_attention做下区分,第1篇代码标记src_attn,第2篇代码标记dec_enc_attn),soft_attention的目标:给定序列Q(query,长度记为lq,维度dk),键序列K(key,长度记为lk,维度dk),值序列V(value,长度记为lv,维度dv),计算Q和K的相似度权重,最后再乘上V。
下面直接贴上attention-is-all-you-need-pytorch中MultiHeadAttention代码
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
#这里把batch和分块数放在一起,便于使用bmm
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
和RNN中的attention的不同,这里的batch_size和seq_len均不为1,其把序列视为一个整体,求Q和V的相似度可使用点乘(V可以视为上面提及的encoder_outputs),获得的是一个相似度矩阵,比如Q是一个长度为10的序列,K是一个长度为16的序列,其相似度矩阵就是一个10*16的矩阵,再如矩阵第一行表示Q的第一个单词和K序列所有单词的相似度。
similarity:=(lq,dk)∗(dk,lk)=(lq,lk) similarity:=(lq,dk)*(dk,lk)=(lq,lk)
similarity:=(lq,dk)∗(dk,lk)=(lq,lk)
然后,生成带注意力机制的词向量(通常K和V取相同的值,因而有lv=lk),另外上面整合attn_applied和input使用的是cat操作,而这里使用的是残差(类似于unet和resnet),最后使用PositionwiseFeedForward(2个fc层)来融合attn_applied和input,最终生成带注意力机制的词向量。
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv) attention\_applied=(lq,lk)*(lv,dv)=(lq,dv)
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv)
细节部分
在数据预处理部分,对序列s都进行了首尾标记,比如s=’’+ s + ‘’,刚看transform(之前跳过了seq2seq),对下面的代码甚是不解
decoder_input=target_seq[:, :-1] #这里不是去掉终止标记<eos>,去掉的可能是padding_0,只为兼容target_ground_y的序列长度?
encoder_input=input_seq[:, 1:] #encoder的输入序列去掉了起始标记<sos>
target_ground_y= target_seqtrg[:, 1:] #用于计算模型loss的target,去掉了起始标记<sos>
1
2
3
其实在pytorch官方教程中说的比较清楚,看下图
encoder的输入序列和ground_true只需要一个终止符即可,而decoder的输入序列开始必须指定一个起始符,让其根据context预测输出序列的第一个单词,后面根据前一个单词再预测下一个单词,依次类推直到当前预测的单词为终止标记’eos’,才计算loss.
---------------------
作者:PJ-Javis
来源:CSDN
原文:https://blog.csdn.net/jiangpeng59/article/details/84859640
版权声明:本文为博主原创文章,转载请附上博文链接!
pytorch笔记:09)Attention机制的更多相关文章
- Multimodal —— 看图说话(Image Caption)任务的论文笔记(三)引入视觉哨兵的自适应attention机制
在此前的两篇博客中所介绍的两个论文,分别介绍了encoder-decoder框架以及引入attention之后在Image Caption任务上的应用. 这篇博客所介绍的文章所考虑的是生成captio ...
- Multimodal —— 看图说话(Image Caption)任务的论文笔记(二)引入attention机制
在上一篇博客中介绍的论文"Show and tell"所提出的NIC模型采用的是最"简单"的encoder-decoder框架,模型上没有什么新花样,使用CNN ...
- 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...
- 【学习笔记】注意力机制(Attention)
前言 这一章看啥视频都不好使,啃书就完事儿了,当然了我也没有感觉自己学的特别扎实,不过好歹是有一定的了解了 注意力机制 由于之前的卷积之类的神经网络,选取卷积中最大的那个数,实际上这种行为是没有目的的 ...
- DL4NLP —— seq2seq+attention机制的应用:文档自动摘要(Automatic Text Summarization)
两周以前读了些文档自动摘要的论文,并针对其中两篇( [2] 和 [3] )做了presentation.下面把相关内容简单整理一下. 文本自动摘要(Automatic Text Summarizati ...
- 论文笔记:Attention Is All You Need
Attention Is All You Need 2018-04-17 10:35:25 Paper:http://papers.nips.cc/paper/7181-attention-is-a ...
- [NLP/Attention]关于attention机制在nlp中的应用总结
原文链接: https://blog.csdn.net/qq_41058526/article/details/80578932 attention 总结 参考:注意力机制(Attention Mec ...
- Java:并发笔记-09
Java:并发笔记-09 说明:这是看了 bilibili 上 黑马程序员 的课程 java并发编程 后做的笔记 7. 共享模型之工具-2 原理:AQS 原理 对于 AQS 的原理这部分内容,没很好的 ...
- Mongodb源代码阅读笔记:Journal机制
Mongodb源代码阅读笔记:Journal机制 Mongodb源代码阅读笔记:Journal机制 涉及的文件 一些说明 PREPLOGBUFFER WRITETOJOURNAL WRITETODAT ...
随机推荐
- bzoj1483: [HNOI2009]梦幻布丁(vector+启发式合并)
1483: [HNOI2009]梦幻布丁 Time Limit: 10 Sec Memory Limit: 64 MBSubmit: 4022 Solved: 1640[Submit][Statu ...
- 通过爬虫爬取四川省公共资源交易平台上最近的招标信息 --- URLConnection
通过爬虫爬取公共资源交易平台(四川省)最近的招标信息 一:引入JSON的相关的依赖 <dependency> <groupId>net.sf.json-lib< ...
- (DP)51NOD 1183 编辑距离
编辑距离,又称Levenshtein距离(也叫做Edit Distance),是指两个字串之间,由一个转成另一个所需的最少编辑操作次数.许可的编辑操作包括将一个字符替换成另一个字符,插入一个字符,删除 ...
- POJ 1686 Lazy Math Instructor(栈)
原题目网址:http://poj.org/problem?id=1686 题目中文翻译: Description 数学教师懒得在考卷中给一个问题评分,因为这个问题中,学生会为所问的问题提出一个复杂的公 ...
- B Balala Power!
Bryce1010模板 每个字母所在位置对应权值加和,肯定存不下. 但我们只需要26个字母对应值之间的关系即可,开一个数组a[i][j]分别记录字母i在j这个位置上出现了多少次,对于大于26的值进位, ...
- Lomsat gelral cf-600e
http://codeforces.com/contest/600/problem/E 暴力启发式合并就行了 提示:set的swap的复杂度是常数,这方面可以放心 我先打了一个很naive的算法 #i ...
- 题解报告:hdu 1263 水果
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1263 Problem Description 夏天来了~~好开心啊,呵呵,好多好多水果~~ Joe经营 ...
- 比较C#中几种常见的复制字节数组方法的效率[转]
[原文链接] 在日常编程过程中,我们可能经常需要Copy各种数组,一般来说有以下几种常见的方法:Array.Copy,IList<T>.Copy,BinaryReader.ReadByte ...
- sed练习第一节
ed语法和基本命令 employee.txt文件内容如下: 101,John Doe,CEO 102,Jason Smith,IT Manager 103,Raj Reddy,Sysadmin 104 ...
- h5学习-h5嵌入android中
嵌入Android中的h5界面: 将此页面复制到android项目中的assets目录下边: <!DOCTYPE html> <html lang="en"> ...