Tensorflow动态seq2seq使用总结(r1.3)
https://www.jianshu.com/p/c0c5f1bdbb88
动机
其实差不多半年之前就想吐槽Tensorflow的seq2seq了(后面博主去干了些别的事情),官方的代码已经抛弃原来用静态rnn实现的版本了,而官网的tutorial现在还是介绍基于静态的rnn的模型,加bucket那套,看这里。
看到了吗?是legacy_seq2seq的。本来Tensorflow的seq2seq的实现相比于pytorch已经很复杂了,还没有个正经的tutorial,哎。
好的,回到正题,遇到问题解决问题,想办法找一个最佳的Tensorflow的seq2seq解决方案!
学习的资料
- 知名博主WildML给google写了个通用的seq2seq,文档地址,Github地址。这个框架已经被Tensorflow采用,后面我们的代码也会基于这里的实现。但本身这个框架是为了让用户直接写参数就能简单地构建网络,因此文档没有太多参考价值,我们直接借用其中的代码构建自己的网络。
- 俄罗斯小伙ematvey写的:tensorflow-seq2seq-tutorials,Github地址。介绍使用动态rnn构建seq2seq,decoder使用
raw_rnn
,原理和WildML的方案差不多。多说一句,这哥们当时也是吐槽Tensorflow的文档,写了那么个仓库当第三方的文档使,现在都400+个star了。真是有漏洞就有机遇啊,哈哈。
Tensorflow的动态rnn
先来简单介绍动态rnn和静态rnn的区别。
tf.nn.rnn creates an unrolled graph for a fixed RNN length. That means, if you call tf.nn.rnn with inputs having 200 time steps you are creating a static graph with 200 RNN steps. First, graph creation is slow. Second, you’re unable to pass in longer sequences (> 200) than you’ve originally specified.tf.nn.dynamic_rnn solves this. It uses a tf.While loop to dynamically construct the graph when it is executed. That means graph creation is faster and you can feed batches of variable size.
摘自Whats the difference between tensorflow dynamic_rnn and rnn?。也就是说,静态的rnn必须提前将图展开,在执行的时候,图是固定的,并且最大长度有限制。而动态rnn可以在执行的时候,将图循环地的复用。
一句话,能用动态的rnn就尽量用动态的吧。
Seq2Seq结构分析
seq2seq由Encoder和Decoder组成,一般Encoder和Decoder都是基于RNN。Encoder相对比较简单,不管是多层还是双向或者更换具体的Cell,使用原生API还是比较容易实现的。难点在于Decoder:不同的Decoder对应的rnn cell的输入不同,比如上图的示例中,每个cell的输入是上一个时刻cell输出的预测对应的embedding。
如果像上图那样使用Attention,则decoder的cell输入还包括attention加权求和过的context。
通过示例讲解
下面通过一个用seq2seq做slot filling(一种序列标注)的例子讲解。完整代码地址:https://github.com/applenob/RNN-for-Joint-NLU
Encoder的实现示例
# 首先构造单个rnn cell
encoder_f_cell = LSTMCell(self.hidden_size)
encoder_b_cell = LSTMCell(self.hidden_size)
(encoder_fw_outputs, encoder_bw_outputs),
(encoder_fw_final_state, encoder_bw_final_state) = \
tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_f_cell,
cell_bw=encoder_b_cell,
inputs=self.encoder_inputs_embedded,
sequence_length=self.encoder_inputs_actual_length,
dtype=tf.float32, time_major=True)
上面的代码使用了tf.nn.bidirectional_dynamic_rnn
构建单层双向的LSTM的RNN作为Encoder。
参数:
cell_fw
:前向的lstm cellcell_bw
:后向的lstm celltime_major
:如果是True,则输入需要是T×B×E,T代表时间序列的长度,B代表batch size,E代表词向量的维度。否则,为B×T×E。输出也是类似。
返回:
outputs
:针对所有时间序列上的输出。final_state
:只是最后一个时间节点的状态。
一句话,Encoder的构造就是构造一个RNN,获得输出和最后的状态。
Decoder实现示例
下面着重介绍如何使用Tensorflow的tf.contrib.seq2seq
实现一个Decoder。
我们这里的Decoder中,每个输入除了上一个时间节点的输出以外,还有对应时间节点的Encoder的输出,以及attention的context。
Helper
常用的Helper
:
TrainingHelper
:适用于训练的helper。InferenceHelper
:适用于测试的helper。GreedyEmbeddingHelper
:适用于测试中采用Greedy策略sample的helper。CustomHelper
:用户自定义的helper。
先来说明helper是干什么的:参考上面提到的俄罗斯小哥用raw_rnn
实现decoder,需要传进一个loop_fn
。这个loop_fn
其实是控制每个cell在不同的时间节点,给定上一个时刻的输出,如何决定下一个时刻的输入。
helper干的事情和这个loop_fn
基本一致。这里着重介绍CustomHelper
,要传入三个函数作为参数:
initialize_fn
:返回finished
,next_inputs
。其中finished
不是scala,是一个一维向量。这个函数即获取第一个时间节点的输入。sample_fn
:接收参数(time, outputs, state)
返回sample_ids
。即,根据每个cell的输出,如何sample。next_inputs_fn
:接收参数(time, outputs, state, sample_ids)
返回(finished, next_inputs, next_state)
,根据上一个时刻的输出,决定下一个时刻的输入。
BasicDecoder
有了自定义的helper以后,可以使用tf.contrib.seq2seq.BasicDecoder
定义自己的Decoder了。再使用tf.contrib.seq2seq.dynamic_decode
执行decode,最终返回:(final_outputs, final_state, final_sequence_lengths)
。其中:final_outputs
是tf.contrib.seq2seq.BasicDecoderOutput
类型,包括两个字段:rnn_output
,sample_id
。
回到示例
# 传给CustomHelper的三个函数
def initial_fn():
initial_elements_finished = (0 >= decoder_lengths) # all False at the initial step
initial_input = tf.concat((sos_step_embedded, encoder_outputs[0]), 1)
return initial_elements_finished, initial_input
def sample_fn(time, outputs, state):
# 选择logit最大的下标作为sample
prediction_id = tf.to_int32(tf.argmax(outputs, axis=1))
return prediction_id
def next_inputs_fn(time, outputs, state, sample_ids):
# 上一个时间节点上的输出类别,获取embedding再作为下一个时间节点的输入
pred_embedding = tf.nn.embedding_lookup(self.embeddings, sample_ids)
# 输入是h_i+o_{i-1}+c_i
next_input = tf.concat((pred_embedding, encoder_outputs[time]), 1)
elements_finished = (time >= decoder_lengths) # this operation produces boolean tensor of [batch_size]
all_finished = tf.reduce_all(elements_finished) # -> boolean scalar
next_inputs = tf.cond(all_finished, lambda: pad_step_embedded, lambda: next_input)
next_state = state
return elements_finished, next_inputs, next_state
# 自定义helper
my_helper = tf.contrib.seq2seq.CustomHelper(initial_fn, sample_fn, next_inputs_fn)
def decode(helper, scope, reuse=None):
with tf.variable_scope(scope, reuse=reuse):
memory = tf.transpose(encoder_outputs, [1, 0, 2])
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
num_units=self.hidden_size, memory=memory,
memory_sequence_length=self.encoder_inputs_actual_length)
cell = tf.contrib.rnn.LSTMCell(num_units=self.hidden_size * 2)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
cell, attention_mechanism, attention_layer_size=self.hidden_size)
out_cell = tf.contrib.rnn.OutputProjectionWrapper(
attn_cell, self.slot_size, reuse=reuse
)
# 使用自定义helper的decoder
decoder = tf.contrib.seq2seq.BasicDecoder(
cell=out_cell, helper=helper,
initial_state=out_cell.zero_state(
dtype=tf.float32, batch_size=self.batch_size))
# 获取decode结果
final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
decoder=decoder, output_time_major=True,
impute_finished=True, maximum_iterations=self.input_steps
)
return final_outputs
outputs = decode(my_helper, 'decode')
Attntion
上面的代码,还有几个地方没有解释:BahdanauAttention
,AttentionWrapper
,OutputProjectionWrapper
。
先从简单的开始:OutputProjectionWrapper
即做一个线性映射,比如之前的cell的ouput是T×B×D,D是hidden size,那么这里做一个线性映射,直接到T×B×S,这里S是slot class num。wrapper内部维护一个线性映射用的变量:W
和b
。
BahdanauAttention
是一种AttentionMechanism
,另外一种是:BahdanauMonotonicAttention
。具体二者的区别,读者请自行深入调查。关键参数:
num_units
:隐层维度。memory
:通常就是RNN encoder的输出memory_sequence_length=None
:可选参数,即memory的mask,超过长度数据不计入attention。
继续介绍AttentionWrapper
:这也是一个cell wrapper,关键参数:
cell
:被包装的cell。attention_mechanism
:使用的attention机制,上面介绍的。
memory对应公式中的h,wrapper的输出是s。
那么一个AttentionWrapper
具体的操作流程如何呢?看官网给的流程:
Loss Function
tf.contrib.seq2seq.sequence_loss
可以直接计算序列的损失函数,重要参数:
logits
:尺寸[batch_size, sequence_length, num_decoder_symbols]
targets
:尺寸[batch_size, sequence_length]
,不用做one_hot。weights
:[batch_size, sequence_length]
,即mask,滤去padding的loss计算,使loss计算更准确。
后记
这里只讨论了seq2seq在序列标注上的应用。seq2seq还广泛应用于翻译和对话生成,涉及到生成的策略问题,比如beam search。后面会继续研究。除了sample的策略,其他seq2seq的主要技术,本文已经基本涵盖,希望对大家踩坑有帮助。
完整代码:https://github.com/applenob/RNN-for-Joint-NLU
作者:Cer_ml
链接:https://www.jianshu.com/p/c0c5f1bdbb88
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
Tensorflow动态seq2seq使用总结(r1.3)的更多相关文章
- tensorflow动态设置trainable
tensorflow中定义的tf.Variable时,可以通过trainable属性控制这个变量是否可以被优化器更新.但是,tf.Variable的trainable属性是只读的,我们无法动态更改这个 ...
- tensorflow 笔记13:了解机器翻译,google NMT,Attention
一.关于Attention,关于NMT 未完待续... 以google 的 nmt 代码引入 探讨下端到端: 项目地址:https://github.com/tensorflow/nmt 机器翻译算是 ...
- 解析Tensorflow官方English-Franch翻译器demo
今天我们来解析下Tensorflow的Seq2Seq的demo.继上篇博客的PTM模型之后,Tensorflow官方也开放了名为translate的demo,这个demo对比之前的PTM要大了很多(首 ...
- tensorflow读取训练数据方法
1. 预加载数据 Preloaded data # coding: utf-8 import tensorflow as tf # 设计Graph x1 = tf.constant([2, 3, 4] ...
- windows10下如何进行源码编译安装tensorflow
1.获取python3.5.x https://www.python.org/ftp/python/3.5.4/python-3.5.4-amd64.exe 2.安装python3.5.x,默认安装即 ...
- keras系列︱seq2seq系列相关实现与案例(feedback、peek、attention类型)
之前在看<Semi-supervised Sequence Learning>这篇文章的时候对seq2seq半监督的方式做文本分类的方式产生了一定兴趣,于是开始简单研究了seq2seq.先 ...
- 路由器基本配置实验,静态路由和动态RIP路由
实验涉及命令以及知识补充 连线 PC和交换机FastEtherNet接口 交换机和路由器FastEtherNet接口 路由器和路由器Serial接口 serial是串行口,一般用于连接设备,不能连接电 ...
- bert+seq2seq 周公解梦,看AI如何解析你的梦境?【转】
介绍 在参与的项目和产品中,涉及到模型和算法的需求,主要以自然语言处理(NLP)和知识图谱(KG)为主.NLP涉及面太广,而聚焦在具体场景下,想要生产落地的还需要花很多功夫. 作为NLP的主要方向,情 ...
- tensorflow sequence_loss
sequence_loss是nlp算法中非常重要的一个函数.rnn,lstm,attention都要用到这个函数.看下面代码: # coding: utf-8 import numpy as np i ...
随机推荐
- hdu 5762 Teacher Bo 暴力
Teacher Bo 题目连接: http://acm.hdu.edu.cn/showproblem.php?pid=5762 Description Teacher BoBo is a geogra ...
- Codeforces Round #517 (Div. 2, based on Technocup 2019 Elimination Round 2)
Codeforces Round #517 (Div. 2, based on Technocup 2019 Elimination Round 2) #include <bits/stdc++ ...
- 解决Windows x86网易云音乐不能将音乐下载到SD卡的BUG
由于我个人最常用的电脑是Surface pro4 256G版本,装了不少生产力空间还挺吃紧的,音乐之类的必然都存单独的SD卡里.用UWP版本的网易云音乐倒是没问题,最近问题来了,UWP版本的网易云音乐 ...
- Spring_之注解事务 @Transactional
spring 事务注解 默认遇到throw new RuntimeException("...");会回滚需要捕获的throw new Exception("...&qu ...
- IBM MR10i阵列卡配置Raid0/Raid1/Raid5(转)
RAID5配置: 其实RAID0/RAID1都基本一致,只是选择的类型不同. 1. 开机看到ctrl+h的提示按下相应的键,等ServerRaid 10-i卡初始化完成则进入WebBIOS 配置界面: ...
- VMware安装MikroTik RouterOS chr
简单步骤: 1.官网下载ova镜像 2.导入到vmware即可.
- Struts2 高危漏洞修复方案 (S2-016/S2-017)
近期Struts2被曝重要漏洞,此漏洞影响struts2.0-struts2.3所有版本,可直接导致服务器被远程控制从而引起数据泄漏,影响巨大,受影响站点以电商.银行.门户.政府居多. 官方描述:S2 ...
- 【Android基础篇】TabWidget设置背景和字体
在使用TabHost实现底部导航栏时,底部导航栏的三个导航button无法在布局文件中进行定制.比方设置点击时的颜色.字体的大小及颜色等,这里提供了一个解决的方法.就是在代码里进行定制. 思路是在Ac ...
- Android:活动的启动模式
启动模式一共有四种,分别是 standard .singleTop . singleTask 和 singleInstance , 可 以 在 AndroidManifest.xml 中 通 过 给 ...
- 无需SherlockActionbar的SlidingMenu使用详解(二)——向Fragment中添加ViewPager和Tab
之前我们对大体框架有了一定的认识,现在我们来做Fragment界面,其实这里面和这个框架的关系就不大了,但因为有些同学对于在SlidingMenu中切换fragment还是有问题,所以我就在本篇进行详 ...