pointer-network是最近seq2seq比较火的一个分支,在基于深度学习的阅读理解,摘要系统中都被广泛应用。

感兴趣的可以阅读原paper 推荐阅读

https://medium.com/@devnag/pointer-networks-in-tensorflow-with-sample-code-14645063f264

 
 

这个思路也是比较简单
就是解码的预测限定在输入的位置上
这在很多地方有用

比如考虑机器翻译的大词典问题,词汇太多了很多词是长尾的,词向量训练是不充分的,那么seq2seq翻译的时候很难翻译出这些词
另外专名什么的
很多是可以copy到
解码输出的

另外考虑文本摘要,很多时候就是要copy输入原文中的词,特别是长尾专名
更好的方式是copy而不是generate

 
 

网络上有一些pointer-network的实现,比较推荐

 https://github.com/ikostrikov/TensorFlow-Pointer-Networks

这个作为入门示例比较好,使用简单的static rnn 实现更好理解,当然 dynamic速度更快,但是从学习角度

先实现static更好一些。

Dynamic rnn的 pointer network实现

https://github.com/devsisters/pointer-network-tensorflow 

这里对static rnn实现的做了一个拷贝并做了小修改,改正了其中的一些问题
参见
https://github.com/chenghuige/hasky/tree/master/applications/pointer-network/static

 
 

这个小程序对应的应用是输入一个序列
比如,输出排序结果

 
 

我们的构造数据

python dataset.py

EncoderInputs: [array([[ 0.74840968]]), array([[ 0.70166106]]), array([[ 0.67414996]]), array([[ 0.9014052]]), array([[ 0.72811645]])]

DecoderInputs: [array([[ 0.]]), array([[ 0.67414996]]), array([[ 0.70166106]]), array([[ 0.72811645]]), array([[ 0.74840968]]), array([[ 0.9014052]])]

TargetLabels: [array([[ 3.]]), array([[ 2.]]), array([[ 5.]]), array([[ 1.]]), array([[ 4.]]), array([[ 0.]])]

 
 

训练过程中的eval展示:

2017-06-07 22:35:52 0:28:19 eval_step: 111300 eval_metrics:

['eval_loss:0.070', 'correct_predict_ratio:0.844']

label--: [ 2 6 1 4 9 7 10 8 5 3 0]

predict: [ 2 6 1 4 9 7 10 8 5 3 0]

label--: [ 1 6 2 5 8 3 9 4 10 7 0]

predict: [ 1 6 2 5 3 3 9 4 10 7 0]

 
 

大概是这样
第一个我们认为是预测完全正确了,
第二个预测不完全正确

 
 

原程序最主要的问题是 Feed_prev 设置为True的时候 原始代码有问题的 因为inp使用的是decoder_input这是不正确的因为

预测的时候其实是没有decoder_input输入的,原代码预测的时候decoder input强制copy/feed了encoder_input

这在逻辑是是有问题的。 实验效果也证明修改成训练也使用encoder_input来生成inp效果好很多。

 
 

那么关于feed_prev我们知道在预测的时候是必须设置为True的因为,预测的时候没有decoder_input我们的下一个输出依赖

上一个预测的输出。

训练的时候我们是用decoder_input序列训练(feed_prev==False)还是也使用自身预测产生的结果进行下一步预测feed_prev==True呢

参考tensorflow官网的说明

In the above invocation, we set feed_previous to False. This means that the decoder will use decoder_inputstensors as provided. If we set feed_previous to True, the decoder would only use the first element of decoder_inputs. All other tensors from this list would be ignored, and instead the previous output of the decoder would be used. This is used for decoding translations in our translation model, but it can also be used during training, to make the model more robust to its own mistakes, similar to Bengio et al., 2015 (pdf).

 
 

来自 <https://www.tensorflow.org/tutorials/seq2seq>

 
 

这里使用

train.sh

train-no-feed-prev.sh
做了对比实验

训练时候使用feed_prev==True效果稍好(红色) 特别是稳定性方差小一些

 
 

 
 

Pointer-network的tensorflow实现-1的更多相关文章

  1. Convolutional Neural Network in TensorFlow

    翻译自Build a Convolutional Neural Network using Estimators TensorFlow的layer模块提供了一个轻松构建神经网络的高端API,它提供了创 ...

  2. (转)The Road to TensorFlow

    Stephen Smith's Blog All things Sage 300… The Road to TensorFlow – Part 7: Finally Some Code leave a ...

  3. TensorFlow简易学习[3]:实现神经网络

    TensorFlow本身是分布式机器学习框架,所以是基于深度学习的,前一篇TensorFlow简易学习[2]:实现线性回归对只一般算法的举例只是为说明TensorFlow的广泛性.本文将通过示例Ten ...

  4. TensorFlow tutorial

    代码示例来自https://github.com/aymericdamien/TensorFlow-Examples tensorflow先定义运算图,在run的时候才会进行真正的运算. run之前需 ...

  5. 5个最好的TensorFlow网络课程

    1. Introduction to TensorFlow for Artificial Intelligence, Machine Learning and Deep Learning This c ...

  6. Recurrent Neural Network[Content]

    下面的RNN,LSTM,GRU模型图来自这里 简单的综述 1. RNN 图1.1 标准RNN模型的结构 2. BiRNN 3. LSTM 图3.1 LSTM模型的结构 4. Clockwork RNN ...

  7. 改善深层神经网络-week3编程题(Tensorflow 实现手势识别 )

    TensorFlow Tutorial Initialize variables Start your own session Train algorithms Implement a Neural ...

  8. Convolutional Neural Network-week1编程题(TensorFlow实现手势数字识别)

    1. TensorFlow model import math import numpy as np import h5py import matplotlib.pyplot as plt impor ...

  9. 吴恩达课后习题第二课第三周:TensorFlow Introduction

    目录 第二课第三周:TensorFlow Introduction Introduction to TensorFlow 1 - Packages 1.1 - Checking TensorFlow ...

随机推荐

  1. JavaScript中date 对象常用方法

    Date 对象 Date 对象用于处理日期和时间. //创建 Date 对象的语法: var datetime = new Date();//Date 对象会自动把当前日期和时间保存为其初始值. co ...

  2. UltralEdit 替换回车换行符

    打开 Ue 工具,写下内容,如下图: 然后按 Ctrl + r,输入 ^p,点击按钮 “全部替换”, 如下图:

  3. [原创]AndroBugs_Framework Android漏洞扫描器介绍

    [原创]AndroBugs_Framework Android漏洞扫描器介绍 1 AndroBugs_Framework Android 漏洞扫描器简介 一款高效的Android漏洞扫描器,可以帮助开 ...

  4. C# 实现Remoting双向通信

    本篇文章主要介绍了C# 实现Remoting双向通信,.Net Remoting 是由客户端通过Remoting,访问通道以获得服务端对象,再通过代理解析为客户端对象来实现通信的 闲来无事想玩玩双向通 ...

  5. 控制WinForm中Tab键的跳转

    一,需求 在Winform中,默认情况下,按下Tab键,光标会按照我们设定的TabIndex值从小到大进行跳转. 但如果用户要求按下Tab键跳转到特定的控件,这种要求还是很合理的,比如用户只想输入几个 ...

  6. 国际化之Android设备支持的语种

    昨天发了关于iOS支持的语种,文章最后也补了安卓支持语种列表.但最后发现安卓设备支持跟它列的有出入,我重新完全手工整理了一遍. 我将对应的语种在安卓的语言列表里的显示,也全部逐一列出来了,方便大家到时 ...

  7. Android如何实现茄子快传

    Android如何实现茄子快传茄子快传是一款文件传输应用,相信大家都很熟悉这款应用,应该很多人用过用来文件的传输.它有两个核心的功能: 端到端的文件传输Web端的文件传输这两个核心的功能我们具体来分析 ...

  8. 罗技Setpoint控制酷狗等第三方播放器

    手里有个淘过来的二手戴尔蓝牙键盘,虽然是戴尔的,但是确实罗技代工的,因此可以使用罗技的Setpoint,用这个软件后可以集中管理罗技的键盘鼠标进行一些个性化设置,如下图所示.不过郁闷的是如果不装Set ...

  9. MySQL DBA工作角色和职责介绍

    MySQL DBA分架构DBA,运维DBA和开发DBA三种角色,职责介绍如下:

  10. Webdings字体和Wingdings字体对照表

    一.Webdings是一个TrueType的dingbat字体,于1997年发表,搭载在其后的Microsoft Windows视窗系统内,大多的字形都没有Unicode的相对字. 使用很简单1.网页 ...