从rnn到lstm,再到seq2seq(二)
从图上可以看出来,decode的过程其实都是从encode的最后一个隐层开始的,如果encode输入过长的话,会丢失很多信息,所以设计了attation机制。
attation机制的decode的过程和原来的最大的区别就是,它输出的不只是基于本时刻的h,而是基于本时刻的h和C的concat矩阵。
那么C是什么,C就是encode的h的联合(见最后一张图的公式),含义非常明显了,就是我在decode的时候,不但考虑我现在decode的隐层的情况,同时也考虑到encode的隐层的情况,那么关键是encode的隐层那么多,你该怎么考虑了,这就是attation矩阵的计算方式。。目前的计算方式是,这个时刻decode的隐层和encode的所有隐层做个对应,最后一张图非常明白
如果你还没有理解,看这个公式,输入的d‘t就是我上面说的C,把这个和dt concat就是本时刻输出的隐层
其实实现起来不复杂,就是在decode的时候,隐层和encode的隐层对应一下,然后concat一下:
下面这个代码是在github上找的,两个隐层对应的方式可能跟上面说的不一样,但是原理都差不多,看这个代码感受一下这个流程。
s = self.encoder.zero_state(self.batch_size, tf.float32)
encoder_hs = []
with tf.variable_scope("encoder"):
for t in xrange(self.max_size):
if t > 0: tf.get_variable_scope().reuse_variables()
x = tf.squeeze(source_xs[t], [1])
x = tf.matmul(x, self.s_proj_W) + self.s_proj_b
h, s = self.encoder(x, s)
encoder_hs.append(h)
encoder_hs = tf.pack(encoder_hs)
s = self.decoder.zero_state(self.batch_size, tf.float32)
logits = []
probs = []
with tf.variable_scope("decoder"):
for t in xrange(self.max_size):
if t > 0: tf.get_variable_scope().reuse_variables()
if not self.is_test or t == 0:
x = tf.squeeze(target_xs[t], [1])
x = tf.matmul(x, self.t_proj_W) + self.t_proj_b
h_t, s = self.decoder(x, s)
h_tld = self.attention(h_t, encoder_hs) oemb = tf.matmul(h_tld, self.proj_W) + self.proj_b
logit = tf.matmul(oemb, self.proj_Wo) + self.proj_bo
prob = tf.nn.softmax(logit)
logits.append(logit)
probs.append(prob) def attention(self, h_t, encoder_hs):
#scores = [tf.matmul(tf.tanh(tf.matmul(tf.concat(1, [h_t, tf.squeeze(h_s, [0])]),
# self.W_a) + self.b_a), self.v_a)
# for h_s in tf.split(0, self.max_size, encoder_hs)]
#scores = tf.squeeze(tf.pack(scores), [2])
scores = tf.reduce_sum(tf.mul(encoder_hs, h_t), 2)
a_t = tf.nn.softmax(tf.transpose(scores))
a_t = tf.expand_dims(a_t, 2)
c_t = tf.batch_matmul(tf.transpose(encoder_hs, perm=[1,2,0]), a_t)
c_t = tf.squeeze(c_t, [2])
h_tld = tf.tanh(tf.matmul(tf.concat(1, [h_t, c_t]), self.W_c) + self.b_c) return h_tld
参考文章:
https://www.slideshare.net/KeonKim/attention-mechanisms-with-tensorflow
https://github.com/dillonalaird/Attention/blob/master/attention.py
http://www.tuicool.com/articles/nUFRban
http://www.cnblogs.com/rocketfan/p/6261467.html
http://blog.csdn.net/jerr__y/article/details/53749693
从rnn到lstm,再到seq2seq(二)的更多相关文章
- RNN、LSTM、Seq2Seq、Attention、Teacher forcing、Skip thought模型总结
RNN RNN的发源: 单层的神经网络(只有一个细胞,f(wx+b),只有输入,没有输出和hidden state) 多个神经细胞(增加细胞个数和hidden state,hidden是f(wx+b) ...
- RNN、LSTM、Char-RNN 学习系列(一)
RNN.LSTM.Char-RNN 学习系列(一) zoerywzhou@gmail.com http://www.cnblogs.com/swje/ 作者:Zhouw 2016-3-15 版权声明 ...
- TensorFlow之RNN:堆叠RNN、LSTM、GRU及双向LSTM
RNN(Recurrent Neural Networks,循环神经网络)是一种具有短期记忆能力的神经网络模型,可以处理任意长度的序列,在自然语言处理中的应用非常广泛,比如机器翻译.文本生成.问答系统 ...
- RNN和LSTM
一.RNN 全称为Recurrent Neural Network,意为循环神经网络,用于处理序列数据. 序列数据是指在不同时间点上收集到的数据,反映了某一事物.现象等随时间的变化状态或程度.即数据之 ...
- 浅谈RNN、LSTM + Kreas实现及应用
本文主要针对RNN与LSTM的结构及其原理进行详细的介绍,了解什么是RNN,RNN的1对N.N对1的结构,什么是LSTM,以及LSTM中的三门(input.ouput.forget),后续将利用深度学 ...
- 3. RNN神经网络-LSTM模型结构
1. RNN神经网络模型原理 2. RNN神经网络模型的不同结构 3. RNN神经网络-LSTM模型结构 1. 前言 之前我们对RNN模型做了总结.由于RNN也有梯度消失的问题,因此很难处理长序列的数 ...
- RNN以及LSTM的介绍和公式梳理
前言 好久没用正儿八经地写博客了,csdn居然也有了markdown的编辑器了,最近花了不少时间看RNN以及LSTM的论文,在组内『夜校』分享过了,再在这里总结一下发出来吧,按照我讲解的思路,理解RN ...
- 深度学习:浅谈RNN、LSTM+Kreas实现与应用
主要针对RNN与LSTM的结构及其原理进行详细的介绍,了解什么是RNN,RNN的1对N.N对1的结构,什么是LSTM,以及LSTM中的三门(input.ouput.forget),后续将利用深度学习框 ...
- 利用RNN(lstm)生成文本【转】
本文转载自:https://www.jianshu.com/p/1a4f7f5b05ae 致谢以及参考 最近在做序列化标注项目,试着理解rnn的设计结构以及tensorflow中的具体实现方法.在知乎 ...
- Naive RNN vs LSTM vs GRU、attention基础
原文地址:https://www.jianshu.com/p/b8653f8b5b2b 一.Recurrent Neural Network 二.Naive RNN Naive RNN更新参数时易出现 ...
随机推荐
- BeanUtils工具类
用对象传参,用JavaBean传参. BeanUtils可以优化传参过程. 学习框架之后,BeanUtils的功能都由框架来完成. 一.为什么用BeanUtils? 每次我们的函数都要传递很多参数很麻 ...
- Redis入门到高可用(十九)——Redis Sentinel
一.Redis Sentinel架构 二.redis sentinel安装与配置 四.客户端连接Sentinel 四.实现原理—— 故障转移演练(客户端高可用) 五.实 ...
- alias用法
echo 'alias msfconsole="pushd $HOME/git/metasploit-framework && ./msfconsole && ...
- 解决IDEA无法安装插件的问题
进入2018年以来,在IDEA插件中心中,安装插件经常安装失败,报连接超时的错误.如下: 我们发现连接IDEA的插件中心使用的是https的链接,我们在浏览器中使用https访问插件中心并不能访问. ...
- C#中的装箱(inboxing)和拆箱(unboxing)(简单理解)
装箱和拆箱是值类型和引用类型之间相互转换是要执行的操作. 装箱:将一个值类型隐式地转换成一个object类型,或把这个值类型转换成一个被该值类型应用的接口类型,把一个值类型的值装箱,就是创建一个ob ...
- 【软件工程1916|W(福州大学)_助教博客】团队第一次作业成绩公示
题目 第一次作业 评分准则: 队名(最好能够体现项目内容,要求有亮点与个性):(1分) 拟作的团队项目描述:一句话(中英文不限):(1分) 队员风采:介绍每一名队员,包括成员性格.擅长的技术.编程的兴 ...
- JDK安装与配置(Windows 7系统)
1.前言 安装之前需弄清JDK.JRE.JVM这几个概念,不然稀里糊涂不知道自己在装什么. (1)什么是java环境:我们知道,想听音乐就要安装音乐播放器,想看图片需要安装图片浏览器,同样道理,要运行 ...
- Linux 配置SSH 无密钥登陆
根据SSH 协议,每次登陆必须输入密码,比较麻烦,SSH还提供了公钥登陆,可以省去输入密码的步骤. 公钥登陆:用户将自己的公钥存储在远程主机上,登陆的时候,远程主机会向用户发送一串随机字符串,用户用自 ...
- DataGridView 访问任意行不崩溃
int index= this.dataGridView1.rows.Add(); 先执行这行代码,然后访问任意行,不崩溃, 赋值不存在的行,只是不显示,或者无值. 什么原理呢? 一些其他 priva ...
- AARRR海盗模型简介
整理下AARRR模型的概念.实际应用场景等问题,初步感觉这个模型主要应用在APP应用分析中. 1.什么是AARRR模型 AARRR是Acquisition.Activation.Retention.R ...