GitHub: https://github.com/zle1992/Seq2Seq-Chatbot

1. Note that at the inference stage the variables must be reused: build the predicting graph inside the same variable scopes with reuse=True so it shares the training weights (a minimal sketch follows this list).

2. If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:

  • The encoder output has been tiled to beam_width via tf.contrib.seq2seq.tile_batch (NOT tf.tile); see the toy comparison after this list.
  • The batch_size argument passed to the zero_state method of this wrapper is equal to true_batch_size * beam_width.
  • The initial state created with zero_state above contains a cell_state value containing properly tiled final state from the encoder.
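
A minimal sketch of the reuse pattern from point 1 (the `linear` helper is hypothetical, just to show the mechanics): in TF 1.x, opening the same variable scope again with reuse=True makes the second graph share the weights created by the first, instead of creating duplicates or raising an error.

    import tensorflow as tf

    def linear(x):
        # Hypothetical helper: tf.get_variable creates 'w' on the first
        # call and, under reuse=True, returns the existing 'w' afterwards.
        w = tf.get_variable('w', shape=[4, 4])
        return tf.matmul(x, w)

    x = tf.placeholder(tf.float32, [None, 4])
    with tf.variable_scope('decoder'):
        train_out = linear(x)   # creates decoder/w
    with tf.variable_scope('decoder', reuse=True):
        infer_out = linear(x)   # reuses decoder/w, no new variables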
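
And a toy comparison for the first bullet: BeamSearchDecoder expects the beam_width copies of each example to sit next to each other, which is what tile_batch produces and tf.tile does not. A small sketch on a dummy 2-example batch:

    import tensorflow as tf

    x = tf.constant([[1], [2]])  # a batch of 2 examples
    grouped = tf.contrib.seq2seq.tile_batch(x, multiplier=3)
    interleaved = tf.tile(x, [3, 1])

    with tf.Session() as sess:
        print(sess.run(grouped))      # [[1] [1] [1] [2] [2] [2]] -- copies grouped per example
        print(sess.run(interleaved))  # [[1] [2] [1] [2] [1] [2]] -- wrong layout for beam search

The full example below builds the training graph, then puts all three requirements together for beam-search inference: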
  import tensorflow as tf
  from tensorflow.python.layers.core import Dense

  BEAM_WIDTH = 5
  BATCH_SIZE = 128

  # INPUTS
  X = tf.placeholder(tf.int32, [BATCH_SIZE, None])
  Y = tf.placeholder(tf.int32, [BATCH_SIZE, None])
  X_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])
  Y_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])

  # ENCODER
  encoder_out, encoder_state = tf.nn.dynamic_rnn(
      cell=tf.nn.rnn_cell.BasicLSTMCell(128),
      inputs=tf.contrib.layers.embed_sequence(X, 10000, 128),
      sequence_length=X_seq_len,
      dtype=tf.float32)

  # DECODER COMPONENTS
  Y_vocab_size = 10000
  decoder_embedding = tf.Variable(tf.random_uniform([Y_vocab_size, 128], -1.0, 1.0))
  projection_layer = Dense(Y_vocab_size)

  # ATTENTION (TRAINING)
  with tf.variable_scope('shared_attention_mechanism'):
      attention_mechanism = tf.contrib.seq2seq.LuongAttention(
          num_units=128,
          memory=encoder_out,
          memory_sequence_length=X_seq_len)

  decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
      cell=tf.nn.rnn_cell.BasicLSTMCell(128),
      attention_mechanism=attention_mechanism,
      attention_layer_size=128)

  # DECODER (TRAINING)
  training_helper = tf.contrib.seq2seq.TrainingHelper(
      inputs=tf.nn.embedding_lookup(decoder_embedding, Y),
      sequence_length=Y_seq_len,
      time_major=False)
  training_decoder = tf.contrib.seq2seq.BasicDecoder(
      cell=decoder_cell,
      helper=training_helper,
      initial_state=decoder_cell.zero_state(BATCH_SIZE, tf.float32).clone(cell_state=encoder_state),
      output_layer=projection_layer)
  with tf.variable_scope('decode_with_shared_attention'):
      training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
          decoder=training_decoder,
          impute_finished=True,
          maximum_iterations=tf.reduce_max(Y_seq_len))
  training_logits = training_decoder_output.rnn_output

  # BEAM SEARCH TILE
  # Tile the encoder outputs, lengths, and final state to beam_width
  # with tile_batch (NOT tf.tile), as BeamSearchDecoder requires.
  encoder_out = tf.contrib.seq2seq.tile_batch(encoder_out, multiplier=BEAM_WIDTH)
  X_seq_len = tf.contrib.seq2seq.tile_batch(X_seq_len, multiplier=BEAM_WIDTH)
  encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=BEAM_WIDTH)

  # ATTENTION (PREDICTING)
  # reuse=True shares the attention variables created for training.
  with tf.variable_scope('shared_attention_mechanism', reuse=True):
      attention_mechanism = tf.contrib.seq2seq.LuongAttention(
          num_units=128,
          memory=encoder_out,
          memory_sequence_length=X_seq_len)

  decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
      cell=tf.nn.rnn_cell.BasicLSTMCell(128),
      attention_mechanism=attention_mechanism,
      attention_layer_size=128)

  # DECODER (PREDICTING)
  predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
      cell=decoder_cell,
      embedding=decoder_embedding,
      start_tokens=tf.tile(tf.constant([1], dtype=tf.int32), [BATCH_SIZE]),
      end_token=2,
      # zero_state must use true_batch_size * beam_width and be cloned
      # with the tiled encoder state.
      initial_state=decoder_cell.zero_state(BATCH_SIZE * BEAM_WIDTH, tf.float32).clone(cell_state=encoder_state),
      beam_width=BEAM_WIDTH,
      output_layer=projection_layer,
      length_penalty_weight=0.0)
  with tf.variable_scope('decode_with_shared_attention', reuse=True):
      predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
          decoder=predicting_decoder,
          impute_finished=False,
          maximum_iterations=2 * tf.reduce_max(Y_seq_len))
  # Token ids of the top-scoring beam (despite the name, these are ids, not logits).
  predicting_logits = predicting_decoder_output.predicted_ids[:, :, 0]

  print('successful')
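
As a quick smoke test, the untrained graph can be run on dummy data. A minimal sketch, assuming the token conventions already baked into the graph above (id 1 as the start token, id 2 as the end token), with random ids standing in for a real dataset:

    import numpy as np

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feed = {
            X: np.random.randint(3, 10000, size=(BATCH_SIZE, 20), dtype=np.int32),
            Y: np.random.randint(3, 10000, size=(BATCH_SIZE, 20), dtype=np.int32),
            X_seq_len: np.full(BATCH_SIZE, 20, dtype=np.int32),
            Y_seq_len: np.full(BATCH_SIZE, 20, dtype=np.int32),
        }
        best_beam = sess.run(predicting_logits, feed_dict=feed)
        print(best_beam.shape)  # (BATCH_SIZE, decode_steps): token ids of the best beam

With untrained weights the decoded ids are meaningless; the point is only to confirm that the tiled shapes, the zero_state batch size, and the reused scopes all line up without errors.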

References:

https://gist.github.com/higepon/eb81ba0f6663a57ff1908442ce753084

https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/BeamSearchDecoder

https://github.com/tensorflow/nmt#beam-search
