seq2seq:

seq2seq就是将输入序列经过encoder-decoder变成目标序列。

如图所示,输入序列是 [A, B, C, <EOS>],输出序列是  [W, X, Y, Z, <EOS>]

encoder-decoder:

主要过程就是用RNN对输入序列进行编码,然后再用RNN对上下文向量进行解码。

实现方式:

1、tf.nn.dynamic_rnn

    参考:https://github.com/ematvey/tensorflow-seq2seq-tutorials/blob/master/1-seq2seq.ipynb

流程:

输入序列: [A,B,C,EOS],其中A,B,C, EOS都要进行embedding,encoder部分的代码如下所示:

encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)

encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
encoder_cell, encoder_inputs_embedded,
dtype=tf.float32, time_major=True,
)

 encoder_outputs是一个时间步的输出,这个在decoder中用不到。encoder_final_stata是最后一层的输出结果,encoder_final_state是一个二元组,(整体的记忆c,隐藏层状态h),然后用encoder_final_state来初始化decoder的状态,而decoder的输入序列为 [EOS, A, B, C],因为dynamic_rnn不能根据上一步的输出来作为当前的输入,所以对于输入来说是固定,而非动态变化的。

decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)

decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
decoder_cell, decoder_inputs_embedded, initial_state=encoder_final_state, dtype=tf.float32, time_major=True, scope="plain_decoder",
)

2、tf.nn.raw_rnn

这种方法不像dynamic_rnn那样固定,它比较灵活,可以通过迭代函数改变每一个时间步的 输入状态、输入。

offical document

tf.nn.raw_rnn(
cell, //基础神经元
loop_fn, //迭代函数,每次的状态与输入都可以在这里定义
parallel_iterations=None,
swap_memory=False,
scope=None
)
输出:
(emit_ta, final_state, final_loop_state),其中emit_ta是TensorArray类型,其实就是每一个时间步输出的tensor的数组,final_state最后的状态,final_loop_state这个好像是None,不知道啥作用

实现步骤

整体:

decoder_outputs_ta, decoder_final_state, _ = tf.nn.raw_rnn(decoder_cell, loop_fn) //decoder_cell是基础神经单元,loop_fn是迭代函数

迭代函数:

//迭代函数包含time, previous_output, previous_state, previous_loop_state(这个相当于LSTM中那个全局的记忆)
def loop_fn(time, previous_output, previous_state, previous_loop_state):
if previous_state is None: # time == 0, 初始化
assert previous_output is None and previous_state is None
return loop_fn_initial()
else:
return loop_fn_transition(time, previous_output, previous_state, previous_loop_state) //在上一个时间步结束后,即将进入当前时间步时会执行该函数,目的就是确定要将哪些内容传给下一步作为状态输入和输入向量

初始化函数:

def loop_fn_initial():
initial_elements_finished = (0 >= decoder_lengths) # all False at the initial step
initial_input = eos_step_embedded #第一步的输入是EOS
initial_cell_state = encoder_final_state #状态输入就是encoder的最终输出状态,包括(c,h)
initial_cell_output = None
initial_loop_state = None # we don't need to pass any additional information
return (initial_elements_finished,
initial_input,
initial_cell_state,
initial_cell_output,
initial_loop_state)

迭代函数:

def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
#如何获取上一步的输出
yhat = softmax(previous_output * W + b)
然后概率最大的那个yhat即为上一步的输出结果,并对这个结果进行embedding,作为下一步的输入 def get_next_input():
output_logits = tf.add(tf.matmul(previous_output, W), b)
prediction = tf.argmax(output_logits, axis=)
next_input = tf.nn.embedding_lookup(embeddings, prediction)
return next_input

#判断是否停止,常数 >= tensor向量,tensor中每个位置都要和常数进行比较,结果是一个布尔型的tensor向量
elements_finished = (time >= decoder_lengths) # this operation produces boolean tensor of [batch_size]
# defining if corresponding sequence has ended
#因为这是一个batch块,所以该batch完成的标志是 所有的item都finish,所以需要reduce_all
finished = tf.reduce_all(elements_finished) # -> boolean scalar
#当前步的输入 = 上一步的输出(get_next_input)
#tf.cond(条件,True时调用的函数, False时调用的函数)
input = tf.cond(finished, lambda: pad_step_embedded, get_next_input)
state = previous_state #状态不用改变直接传过去
output = previous_output #previous_output也不用变,好像这个output是一个TensorArray吧?
loop_state = None return (elements_finished,
input,
state,
output,
loop_state)

调用过程:

decoder_outputs_ta, decoder_final_state, _ = tf.nn.raw_rnn(decoder_cell, loop_fn)

 这样就实现了将上一步decoder出来的结果作为下一步的输入,真正实现上图中的过程。

待补充Attention机制

参考:

https://github.com/ematvey/tensorflow-seq2seq-tutorials

https://hanxiao.github.io/2017/08/16/Why-I-use-raw-rnn-Instead-of-dynamic-rnn-in-Tensorflow-So-Should-You-0/

seq2seq的更多相关文章

  1. DL4NLP —— seq2seq+attention机制的应用:文档自动摘要(Automatic Text Summarization)

    两周以前读了些文档自动摘要的论文,并针对其中两篇( [2] 和 [3] )做了presentation.下面把相关内容简单整理一下. 文本自动摘要(Automatic Text Summarizati ...

  2. 深度学习之seq2seq模型以及Attention机制

    RNN,LSTM,seq2seq等模型广泛用于自然语言处理以及回归预测,本期详解seq2seq模型以及attention机制的原理以及在回归预测方向的运用. 1. seq2seq模型介绍 seq2se ...

  3. 深度学习之 seq2seq 进行 英文到法文的翻译

    深度学习之 seq2seq 进行 英文到法文的翻译 import os import torch import random source_path = "data/small_vocab_ ...

  4. ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档]

    ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档] 简介 简单地说就是该有的都有了,但是总体跑起来效果还不好. 还在开发中,它工作的效果还不好.但是你可以直 ...

  5. ChatGirl is an AI ChatBot based on TensorFlow Seq2Seq Model

    Introduction [Under developing,it is not working well yet.But you can just train,and run it.] ChatGi ...

  6. tf.contrib.seq2seq.sequence_loss example:seqence loss 实例代码

    #!/usr/bin/env python # -*- coding: utf-8 -*- import tensorflow as tf import numpy as np params=np.r ...

  7. 深度学习之注意力机制(Attention Mechanism)和Seq2Seq

    这篇文章整理有关注意力机制(Attention Mechanism )的知识,主要涉及以下几点内容: 1.注意力机制是为了解决什么问题而提出来的? 2.软性注意力机制的数学原理: 3.软性注意力机制. ...

  8. Pytorch系列教程-使用Seq2Seq网络和注意力机制进行机器翻译

    前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutor ...

  9. [转] 图解Seq2Seq模型、RNN结构、Encoder-Decoder模型 到 Attention

    from : https://caicai.science/2018/10/06/attention%E6%80%BB%E8%A7%88/ 一.Seq2Seq 模型 1. 简介 Sequence-to ...

随机推荐

  1. Android开发支付集成——微信集成

    支付宝支付传送门:https://www.cnblogs.com/dingxiansen/p/9208949.html 二.微信支付 1. 微信支付流程图 相比较而言,微信支付是要比支付宝麻烦一些,并 ...

  2. vue 外部字体图标使用,无须绝对路径引入办法

    通常外部字体图标都在使用 iconfont ,这种图标在网上搜到一大把都是由于路径问题显示不出来,或者是显示个方块. 最近的项目中也碰到这个坑爸的问题,总结一下解决办法: 和 webpack.conf ...

  3. Java关于日期的计算持续汇总~

    /** * 00 * 描述:传入Date date.转为 String yyyyMMdd. * [时间 2019-04-18 15:41:12 作者 陶攀峰] */ public static Str ...

  4. Redmine入门-安装

    Redmine提供了两种方式安装,如果仅仅只是使用Redmine,建议采用一键安装的方式,快捷方便.如果需要做二次开发或者更多的个性化处理,可以采用源码安装方式,下面分别介绍两种安装方式. ----- ...

  5. Django-2- 模板路径查找,模板变量,模板过滤器,静态文件引用

    模板路径查找 路径配置 2. templates模板查找有两种方式 2.1 - 在APP目录下创建templates文件夹,在文件夹下创建模板 2.2 - 在项目根目录下创建templates文件夹, ...

  6. 网络浅析(<<网络是怎么连接的>> 总结)

    概要 基本概念 网线 集线器 交换机 路由器 路由器和交换机 路由器和集线器 接入网 IP DNS 以太网 协议栈 网络连接过程 通信过程(浏览器 -> 服务器) 客户端和服务端 服务端的套接字 ...

  7. vim编辑器操作

    vim被称为编辑器之神,另外一个是sublime.vim较vi比较高级,vi适用于文本编辑,vim更加适合于coding.凡是vim里面的命令在vi都是适用的. vim的大众版的三种模式(其实不止三种 ...

  8. Webstorm 2017.3激活破解

    之前尝试过各种激活破解办法,不过随着版本的不断升级,激活信息都失效了(毕竟咱不是通过正常途径激活的),只能重新激活.而且难度越来越大,记得早先网上有人分享激活码,激活的server地址,破解程序等等, ...

  9. Java基础系列--06_抽象类与接口概述

    抽象类 (1)如果多个类中存在相同的方法声明,而方法体不一样,我们就可以只提取方法声明. 如果一个方法只有方法声明,没有方法体,那么这个方法必须用抽象修饰. 而一个类中如果有抽象方法,这个类必须定义为 ...

  10. SQLAchemy模块

    老师的博客:http://www.cnblogs.com/wupeiqi/articles/5713330.html 有一篇习详细的博客: http://www.keakon.net/2012/12/ ...