新版seqseq接口说明
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=FLAGS.rnn_hidden_size, memory = encoder_outputs, memory_sequence_length = encoder_sequence_length)
这一步创造一个attention_mechanism。通过__call__(self, query, previous_alignments)来调用,输入query也就是decode hidden,输入previous_alignments是encode hidden,输出是一个attention概率矩阵
helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(inputs, tf.to_int32(sequence_length), emb, tf.constant(FLAGS.scheduled_sampling_probability))
创建一个helper,用来处理每个时刻的输入和输出
my_decoder = tf.contrib.seq2seq.BasicDecoder(cell = cell, helper = helper, initial_state = state)
调用的核心部分。通过def step(self, time, inputs, state, name=None)来控制每一个进行decode
首先把inputs和attention进行concat作为输入。(为什么这样做,参考LSTM的实现 W1U+W2V,其实是把U,V concat在乘以一个W),那么这里inputs就是U,attention就是V(其实tf.concat(query,attention矩阵 * memory)在做个outpreject)。
outputs, state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope='seq_decode')
最后通过dynamic_decode来控制整个flow
写到前面:
先看:
class BasicRNNCell(RNNCell):
def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
if self._linear is None:
self._linear = _Linear([inputs, state], self._num_units, True)
这个是核心,也就是W * input + U * state + B的实现,tf是用_Linear来实现的(_Linear的实现就是把input和state进行concat,然后乘以一个W)。由于rnn只有hidden,所以这里的state就是hidden
再看
class BasicLSTMCell(RNNCell):
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
if self._linear is None:
self._linear = _Linear([inputs, h], 4 * self._num_units, True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)
new_c = (
c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
就非常明显了,由于lstm的state是由两部分构成的,一个是hidden,一个是state,第一步先split。之后用inputs和h进行linear,由于我们要输出4个结果,记得输出维度一定要是4*_num_units。然后根据公式再进行后面的操作,最后返回新的hidden和state,也很直观。
之后再看,加入attention之后怎么弄:
我们这里的attention为encode hidden,那么根据公式是attention和decode hidden进行concat作为一个大的hidden,之后和inputs一起进入网络。
但是,tf实现的时候是这样子的,首先把attention和inputs进行concat,之后把连接的结果作为inputs和decode hidden一起送入网络。为什么能这么做呢,是因为在网络内部其实也是concat之后再linear,参考上面的BasicLSTMCell实现,所有关键就是把(inputs,attention,decode hidden)concat一起就行了,不管顺序是啥。说道这里你终于明白了AttentionWrapper到底是干啥的了。那么attention怎么计算呢,有个_compute_attention函数。我感觉就是非常直接了,attention_mechanism是你需要的attention映射矩阵的方式,
def _compute_attention(attention_mechanism, cell_output, previous_alignments,
attention_layer):
"""Computes the attention and alignments for a given attention_mechanism."""
alignments = attention_mechanism(
cell_output, previous_alignments=previous_alignments)
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, memory_size]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, memory_size].
# we then squeeze out the singleton dim.
context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context = array_ops.squeeze(context, [1])
if attention_layer is not None:
attention = attention_layer(array_ops.concat([cell_output, context], 1))
else:
attention = context
return attention, alignments
新版seqseq接口说明的更多相关文章
- 虹软最新版 python 接口 完整版
虹软最新版 python 接口 完整版 当前开源的人脸检测模型,识别很多,很多小伙伴也踩过不少坑.相信不少使用过dlib和facenet人脸识别的小伙伴都有这样的疑惑,为什么论文里高达99.8以上的准 ...
- javascript使用H5新版媒体接口navigator.mediaDevices.getUserMedia,做扫描二维码,并识别内容
本文代码测试要求,最新的chrome浏览器(手机APP),并且要允许chrome拍照录像权限,必须要HTTPS协议,http不支持. 原理:调用摄像头,将摄像头返回的媒体流渲染到视频标签中,再通过ca ...
- 夺命雷公狗---微信开发55----微信js-sdk接口开发(2)接口功能介绍之签名算法
我们JS-SDK里面其实有不少的接口 startRecord---录音 stopRecord---停止录音 playVoice---播放 pauseVoice---暂停播放 uploadImage-- ...
- 使用Github Pages建独立博客
http://beiyuu.com/github-pages/ Github很好的将代码和社区联系在了一起,于是发生了很多有趣的事情,世界也因为他美好了一点点.Github作为现在最流行的代码仓库,已 ...
- 微信:JSSDK开发
根据微信开发文档步骤如下: 1.先登录微信公众平台进入“公众号设置”的“功能设置”里填写“JS接口安全域名”. JS接口安全域名设置 mi.com(前面不用带www/http,域名必须备案过) 2.引 ...
- 微信公众平台开发 微信JSSDK开发
根据微信开发文档步骤如下: 1.先登录微信公众平台进入“公众号设置”的“功能设置”里填写“JS接口安全域名”. JS接口安全域名设置 mi.com(前面不用带www/http,域名必须备案过) 2.引 ...
- 【周年版】Cnblogs for Android
前言 扒衣见君节刚过去但是炎热夏天还在继续: 自14年8月推出博客园Android客户端以来,断断续续发了十几个后续版本,期间出现过各种问题,由于接口等诸多因素,每个模块的功能都可能随着时间和博客园主 ...
- 微信公众平台JSSDK开发
根据微信开发文档步骤如下: 1.先登录微信公众平台进入“公众号设置”的“功能设置”里填写“JS接口安全域名”.JS接口安全域名设置 mi.com(前面不用带www/http,域名必须备案过) 2.引入 ...
- 微信JS-SDK
<div class="lbox_close wxapi_form"> <h3 id="menu-basic">基础接口</h3& ...
随机推荐
- java 访问数据库
Class.forName(“com.microsoft.sqlserver.jdbc.SQLServerDriver”);//依据不同数据库,加载不同驱动 String url = “jdbc:sq ...
- 用Python3实现的Mycin专家系统简单实例
from sys import stderr ######################### TRUE = 1 #定义返回值 FALSE = 0 FACT_LENGTH = 9 #'''前提与结论 ...
- JVM内存回收机制——哪些内存需要被回收(JVM学习系列2)
上一篇文章中讨论了Java内存运行时的各个区域,其中程序计数器.虚拟机栈.本地方法栈随线程生灭,且创建时需要多少内存,基本上在译期间就决定的了,所以在内存回收时无需特殊的关注.而堆和方法区则不同,首先 ...
- Linux基础命令---sar显示系统活动信息
sar sar指令用来收集.报告.保存系统的活动信息.sar命令将操作系统中选定的累积活动计数器的内容写入标准输出.会计系统根据参数“interval”.“count”中的值,写入以秒为单位的指定间隔 ...
- PAT (Advanced Level) Practice 1001 A+B Format (20 分)
题目链接:https://pintia.cn/problem-sets/994805342720868352/problems/994805528788582400 Calculate a+b and ...
- 【融云分析】如何实现分布式场景下唯一 ID 生成?
◀背景▶ 对于一套分布式部署的 IM 系统,要求每条消息的 ID 要保证在集群中全局唯一且按生成时间有序排列.如何快速高效的生成消息数据的唯一 ID ,是影响系统吞吐量的关键因素.那么,融云是如何做到 ...
- C++ 创建快捷方式
https://blog.csdn.net/morewindows/article/details/6686683
- hdu1172(枚举)
中文题,题意就不解释了. 思路:因为答案一定是四位数,所以只要枚举1000-9999,如果符合所有条件,那么保存一下答案,记录一下答案的个数,如果答案是唯一的,那么输出它,否则,就不确定. 代码如下: ...
- 关于用IIS在.net平台发布网页的一些坑
说明:由于需要显示页面的表格的内容,要用pageOffice插件,而装pageoffice之前需要装.net3.5,直接导入. 为什么要分别装.net4.5和.net3.5 ? 都要装? 问题:刚才 ...
- 原创《weex面向未来的架构》
最近一直在做weex的调研工作,整理之后给公司做了一次技术分享. 分享内容如下: 1:Weex是什么? 2: Weex目前能做什么? 3: Weex 如何调试 4: 剖析一下Weex原理 5: ...