attention 介绍
前言
这里学习的注意力模型是我在研究image caption过程中的出来的经验总结,其实这个注意力模型理解起来并不难,但是国内的博文写的都很不详细或说很不明确,我在看了 attention-mechanism后才完全明白。得以进行后续工作。
这里的注意力模型是论文 Show,Attend and Tell:Neural Image Caption Generation with Visual Attention里设计的,但是注意力模型在大体上来讲都是相通的。
先给大家介绍一下我需要注意力模型的背景。
I是图片信息矩阵也就是[224,224,3],通过前面的cnn也就是所谓的sequence-sequence模型中的encoder,我用的是vgg19,得到a,这里的a其实是[14*14,512]=[196,512],很形象吧,代表的是图片被分成了这么多个区域,后面就看我们单词注意在哪个区域了,大家可以先这么泛泛理解。通过了本文要讲的Attention之后得到z。这个z是一个区域概率,也就是当前的单词在哪个图像区域的概率最大。然后z组合单词的embedding去训练。
好了,先这么大概理解一下这张图就好。下面我们来详细解剖attention,附有代码~
attention的内部结构是什么?
这里的c其实一个隐含输入,计算方式如下
首先我们这么个函数:
def _get_initial_lstm(self, features):
with tf.variable_scope('initial_lstm'):
features_mean = tf.reduce_mean(features, 1)
w_h = tf.get_variable('w_h', [self.D, self.H], initializer=self.weight_initializer)
b_h = tf.get_variable('b_h', [self.H], initializer=self.const_initializer)
h = tf.nn.tanh(tf.matmul(features_mean, w_h) + b_h)
w_c = tf.get_variable('w_c', [self.D, self.H], initializer=self.weight_initializer)
b_c = tf.get_variable('b_c', [self.H], initializer=self.const_initializer)
c = tf.nn.tanh(tf.matmul(features_mean, w_c) + b_c)
return c, h
上面的c你可以暂时不用管,是lstm中的memory state,输入feature就是通过cnn跑出来的a,我们暂时考虑batch=1,就认为这个a是一张图片生成的。所以a的维度是[1,196,512]
y向量代表的就是feature。
下面我们打开这个黑盒子来看看里面到底是在做什么处理。
上图中可以看到
这里的tanh不能替换成ReLU函数,一旦替换成ReLU函数,因为有很多负值就会消失,会很影响后面的结果,会造成最后Inference句子时,不管你输入什么图片矩阵的到的句子都是一样的。不能随便用激活函数!!!ReLU是能解决梯度消散问题,但是在这里我们需要负值信息,所以只能用tanh
c和y在输入到tanh之前要做个全连接,代码如下。
w = tf.get_variable('w', [self.H, self.D], initializer=self.weight_initializer)
b = tf.get_variable('b', [self.D], initializer=self.const_initializer)
w_att = tf.get_variable('w_att', [self.D, 1], initializer=self.weight_initializer)
h_att = tf.nn.relu(features_proj + tf.expand_dims(tf.matmul(h, w), 1) + b) # (N, L, D)
这里的features_proj是feature已经做了全连接后的矩阵。并且在上面计算h_att中你可以看到一个矩阵的传播机制,也就是relu函数里的加法。features_proj和后面的那个维度是不一样的。
def _project_features(self, features):
with tf.variable_scope('project_features'):
w = tf.get_variable('w', [self.D, self.D], initializer=self.weight_initializer)
features_flat = tf.reshape(features, [-1, self.D])
features_proj = tf.matmul(features_flat, w)
features_proj = tf.reshape(features_proj, [-1, self.L, self.D])
return features_proj
然后要做softmax了,这里有个点,因为上面得到的m的维度是[1,196,512],1是代表batch数量。经过softmax后想要得到的是维度为[1,196]的矩阵也就是每个区域的注意力权值。所以
out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), w_att), [-1, self.L]) # (N, L)
alpha = tf.nn.softmax(out_att)
最后计算s就是一个相乘。
context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context') #(N, D)
这里也是有个传播的机制,features维度[1,196,512],后面那个维度[1,196,1]。
最后给个完整的注意力模型代码。
def _attention_layer(self, features, features_proj, h, reuse=False):
with tf.variable_scope('attention_layer', reuse=reuse):
w = tf.get_variable('w', [self.H, self.D], initializer=self.weight_initializer)
b = tf.get_variable('b', [self.D], initializer=self.const_initializer)
w_att = tf.get_variable('w_att', [self.D, 1], initializer=self.weight_initializer)
h_att = tf.nn.relu(features_proj + tf.expand_dims(tf.matmul(h, w), 1) + b) # (N, L, D)
out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), w_att), [-1, self.L]) # (N, L)
alpha = tf.nn.softmax(out_att)
context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context') #(N, D)
return context, alpha
如果大家想研究整个完整的show-attend-tell模型,可以去看看github链接
以上我们讲的是soft_attention,还有一个hard_attention。hard_attention比较不适合于反向传播,其原理是取代了我们之前将softmax后的所有结果相加,使用采样其中一个作为z。反向传播的梯度就不好算了,这里用蒙特卡洛采样方式。
ok,回到我们的image_caption中,看下图
这个图其实不太准确,每一个z其实还会用tf.concat连接上当前这个lstm_cell的单词embedding输入。也就是维度变成[512]+[512]=[1024]
这样就可以结合当前单词和图像区域的关系了,我觉得注意力模型还是很巧妙的。
https://segmentfault.com/a/1190000011744246
attention 介绍的更多相关文章
- 6. 从Encoder-Decoder(Seq2Seq)理解Attention的本质
1. 语言模型 2. Attention Is All You Need(Transformer)算法原理解析 3. ELMo算法原理解析 4. OpenAI GPT算法原理解析 5. BERT算法原 ...
- 机器翻译注意力机制及其PyTorch实现
前面阐述注意力理论知识,后面简单描述PyTorch利用注意力实现机器翻译 Effective Approaches to Attention-based Neural Machine Translat ...
- Attention注意力机制介绍
什么是Attention机制 Attention机制通俗的讲就是把注意力集中放在重要的点上,而忽略其他不重要的因素.其中重要程度的判断取决于应用场景,拿个现实生活中的例子,比如1000个人眼中有100 ...
- 模型汇总24 - 深度学习中Attention Mechanism详细介绍:原理、分类及应用
模型汇总24 - 深度学习中Attention Mechanism详细介绍:原理.分类及应用 lqfarmer 深度学习研究员.欢迎扫描头像二维码,获取更多精彩内容. 946 人赞同了该文章 Atte ...
- Seq2Seq和Attention机制入门介绍
1.Sequence Generation 1.1.引入 在循环神经网络(RNN)入门详细介绍一文中,我们简单介绍了Seq2Seq,我们在这里展开一下 一个句子是由 characters(字) 或 w ...
- 关于ArcGIS API for JavaScript中basemap的总结介绍(一)
实际上basemap这个概念并不只在arcgis中才有,在Python中有一个matplotlib basemap toolkit(https://pypi.python.org/pypi/basem ...
- (转)注意力机制(Attention Mechanism)在自然语言处理中的应用
注意力机制(Attention Mechanism)在自然语言处理中的应用 本文转自:http://www.cnblogs.com/robert-dlut/p/5952032.html 近年来,深度 ...
- 论文笔记之:Deep Attention Recurrent Q-Network
Deep Attention Recurrent Q-Network 5vision groups 摘要:本文将 DQN 引入了 Attention 机制,使得学习更具有方向性和指导性.(前段时间做 ...
- 注意力机制(Attention Mechanism)在自然语言处理中的应用
注意力机制(Attention Mechanism)在自然语言处理中的应用 近年来,深度学习的研究越来越深入,在各个领域也都获得了不少突破性的进展.基于注意力(attention)机制的神经网络成为了 ...
随机推荐
- 算法练习LeetCode初级算法之数学
Fizz Buzz class Solution { public List<String> fizzBuzz(int n) { List<String> list=new L ...
- 微信小程序记账本进度七
最后大体上完成了,但是好像少了点功能,整体并不是特别华丽
- VS2017用正则表达式替换多行代码
await IndexManyAsync\(item.Value, item.Key, "doc"\);\r\n.*\}.*\r\n.*\} 上面的代码,匹配的是下面的代码 awa ...
- angular2监听页面大小变化
一.现象 全屏页面中的图表,在很多的时候需要 resize 一把,以适应页面的大小变化 二.解决 1.引入 : import { Observable } from 'rxjs'; 2.使用(在ngO ...
- linux(centos) tomcat设置开机启动
亲测有效 环境: centos7 apache-tomcat-8.5.37 设置步骤: 1.修改/etc/rc.d/rc.local vi /etc/rc.d/rc.local 2.添加下面两行脚本, ...
- 部落划分Group[JSOI2010]
--BZOJ1821 Description 聪聪研究发现,荒岛野人总是过着群居的生活,但是,并不是整个荒岛上的所有野人都属于同一个部落,野人们总是拉帮结派形成属于自己的部落,不同的部落之间则经常发生 ...
- Movavi Video Editor 15 Plus Mac怎样更改视频的分辨率?
使用Movavi Video Editor 15您可以对视频进行切割和修剪,裁剪和旋转,色度键,视频稳定以及画中画等很多的编辑,该软件操作简单,就算是新手也无需担心操作问题,本文讲述了Movavi V ...
- 【MySQL】初识数据库及简单操作
一.数据库概述 1.1 什么是数据(Data) 描述事物的符号记录称为数据,描述事物的符号既可以是数字,也可以是文字.图片,图像.声音.语言等,数据由多种表现形式,它们都可以经过数字化后存入计算机. ...
- SQL 语句中 where 条件后 写上1=1 的意思
这段代码应该是由程序(例如Java)中生成的,where条件中 1=1 之后的条件是通过 if 块动态变化的.例如: String sql="select * from table_nam ...
- (摘录)Java 详解 JVM 工作原理和流程
作为一名Java使用者,掌握JVM的体系结构也是必须的. 说起Java,人们首先想到的是Java编程语言,然而事实上,Java是一种技术,它由四方面组成:Java编程语言.Java类文件格式.Java ...