attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=FLAGS.rnn_hidden_size, memory = encoder_outputs, memory_sequence_length = encoder_sequence_length)

这一步创造一个attention_mechanism。通过__call__(self, query, previous_alignments)来调用,输入query也就是decode hidden,输入previous_alignments是encode hidden,输出是一个attention概率矩阵

helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(inputs, tf.to_int32(sequence_length), emb, tf.constant(FLAGS.scheduled_sampling_probability))

创建一个helper,用来处理每个时刻的输入和输出

my_decoder = tf.contrib.seq2seq.BasicDecoder(cell = cell, helper = helper, initial_state = state)

调用的核心部分。通过def step(self, time, inputs, state, name=None)来控制每一个进行decode

首先把inputs和attention进行concat作为输入。(为什么这样做,参考LSTM的实现 W1U+W2V,其实是把U,V concat在乘以一个W),那么这里inputs就是U,attention就是V(其实tf.concat(query,attention矩阵 * memory)在做个outpreject)。

outputs, state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope='seq_decode')

最后通过dynamic_decode来控制整个flow

写到前面:

先看:

class BasicRNNCell(RNNCell):

def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
if self._linear is None:
self._linear = _Linear([inputs, state], self._num_units, True)

这个是核心,也就是W * input + U * state + B的实现,tf是用_Linear来实现的(_Linear的实现就是把input和state进行concat,然后乘以一个W)。由于rnn只有hidden,所以这里的state就是hidden

再看

class BasicLSTMCell(RNNCell):

if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

if self._linear is None:
self._linear = _Linear([inputs, h], 4 * self._num_units, True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

new_c = (
c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)

if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state

就非常明显了,由于lstm的state是由两部分构成的,一个是hidden,一个是state,第一步先split。之后用inputs和h进行linear,由于我们要输出4个结果,记得输出维度一定要是4*_num_units。然后根据公式再进行后面的操作,最后返回新的hidden和state,也很直观。

之后再看,加入attention之后怎么弄:

我们这里的attention为encode hidden,那么根据公式是attention和decode hidden进行concat作为一个大的hidden,之后和inputs一起进入网络。

但是,tf实现的时候是这样子的,首先把attention和inputs进行concat,之后把连接的结果作为inputs和decode hidden一起送入网络。为什么能这么做呢,是因为在网络内部其实也是concat之后再linear,参考上面的BasicLSTMCell实现,所有关键就是把(inputs,attention,decode hidden)concat一起就行了,不管顺序是啥。说道这里你终于明白了AttentionWrapper到底是干啥的了。那么attention怎么计算呢,有个_compute_attention函数。我感觉就是非常直接了,attention_mechanism是你需要的attention映射矩阵的方式,

def _compute_attention(attention_mechanism, cell_output, previous_alignments,
attention_layer):
"""Computes the attention and alignments for a given attention_mechanism."""
alignments = attention_mechanism(
cell_output, previous_alignments=previous_alignments)

# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, memory_size]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, memory_size].
# we then squeeze out the singleton dim.
context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context = array_ops.squeeze(context, [1])

if attention_layer is not None:
attention = attention_layer(array_ops.concat([cell_output, context], 1))
else:
attention = context

return attention, alignments

新版seqseq接口说明的更多相关文章

  1. 虹软最新版 python 接口 完整版

    虹软最新版 python 接口 完整版 当前开源的人脸检测模型,识别很多,很多小伙伴也踩过不少坑.相信不少使用过dlib和facenet人脸识别的小伙伴都有这样的疑惑,为什么论文里高达99.8以上的准 ...

  2. javascript使用H5新版媒体接口navigator.mediaDevices.getUserMedia,做扫描二维码,并识别内容

    本文代码测试要求,最新的chrome浏览器(手机APP),并且要允许chrome拍照录像权限,必须要HTTPS协议,http不支持. 原理:调用摄像头,将摄像头返回的媒体流渲染到视频标签中,再通过ca ...

  3. 夺命雷公狗---微信开发55----微信js-sdk接口开发(2)接口功能介绍之签名算法

    我们JS-SDK里面其实有不少的接口 startRecord---录音 stopRecord---停止录音 playVoice---播放 pauseVoice---暂停播放 uploadImage-- ...

  4. 使用Github Pages建独立博客

    http://beiyuu.com/github-pages/ Github很好的将代码和社区联系在了一起,于是发生了很多有趣的事情,世界也因为他美好了一点点.Github作为现在最流行的代码仓库,已 ...

  5. 微信:JSSDK开发

    根据微信开发文档步骤如下: 1.先登录微信公众平台进入“公众号设置”的“功能设置”里填写“JS接口安全域名”. JS接口安全域名设置 mi.com(前面不用带www/http,域名必须备案过) 2.引 ...

  6. 微信公众平台开发 微信JSSDK开发

    根据微信开发文档步骤如下: 1.先登录微信公众平台进入“公众号设置”的“功能设置”里填写“JS接口安全域名”. JS接口安全域名设置 mi.com(前面不用带www/http,域名必须备案过) 2.引 ...

  7. 【周年版】Cnblogs for Android

    前言 扒衣见君节刚过去但是炎热夏天还在继续: 自14年8月推出博客园Android客户端以来,断断续续发了十几个后续版本,期间出现过各种问题,由于接口等诸多因素,每个模块的功能都可能随着时间和博客园主 ...

  8. 微信公众平台JSSDK开发

    根据微信开发文档步骤如下: 1.先登录微信公众平台进入“公众号设置”的“功能设置”里填写“JS接口安全域名”.JS接口安全域名设置 mi.com(前面不用带www/http,域名必须备案过) 2.引入 ...

  9. 微信JS-SDK

    <div class="lbox_close wxapi_form"> <h3 id="menu-basic">基础接口</h3& ...

随机推荐

  1. 论文速读(Chuhui Xue——【arxiv2019】MSR_Multi-Scale Shape Regression for Scene Text Detection)

    Chuhui Xue--[arxiv2019]MSR_Multi-Scale Shape Regression for Scene Text Detection 论文 Chuhui Xue--[arx ...

  2. Gvim:unable to load python

    环境 系统win7 64 bit 软件: Gvim8.1 : MS-Windows 32bit 软件: python2.7.14 windows 64bit 问题 点击打开Gvim时,提示:unabl ...

  3. laravel5.7 前后端分离开发 实现基于API请求的token认证

    最近在学习前后端分离开发,发现 在laravel中实现前后台分离是无法无法使用 CSRF Token 认证的.因为 web 请求的用户认证是通过Session和客户端Cookie的实现的,而前后端分离 ...

  4. Fetch和ajax的比较和区别

    传统 Ajax 已死,Fetch 永生   Ajax 不会死,传统 Ajax 指的是 XMLHttpRequest(XHR),未来现在已被 Fetch 替代. 最近把阿里一个千万级 PV 的数据产品全 ...

  5. python-支付宝支付示例

      项目演示: 1.输入金额 2.扫码支付: 3.支付完成: 4.跳转回商户 一.注册账号 https://openhome.alipay.com/platform/appDaily.htm?tab= ...

  6. 打开visual studio 2010报错:未能正确加载“VSTS for Database Professionals Sql Server Data-tier Application”包

    1  解决: 运行cmd 2  输入:regsvr32 %windir%\system32\jscript.dll

  7. 修改Aptana Studio默认编码

    1,修改:Text  file encoding 2,修改:Initial HTML file contents

  8. Task: Indoor Positioning with WiFi Signals

    Task: Indoor Positioning with WiFi SignalsYou are hired by a company to design an indoor localizatio ...

  9. PL/SQL变量和类型

    变量 在定义变量时一定要为其指定一个类型,类型可以是PL/SQL类型或SQL语言的类型,一旦变量的类型确定,那么变量中所能存储的值也就确定了,因此尽管变量的值会经常改变,但是值的类型是不可以变化的. ...

  10. MySQL触发器在建立时,报语法错的问题

    delimiter $$ create trigger trg_delete_on_users before DELETE on users for each row begin delete fro ...