本文从tensorflow的代码层面理解LSTM。

看本文之前,需要先看我的这两篇博客

https://www.cnblogs.com/yanshw/p/10495745.html 谈到网络结构

https://www.cnblogs.com/yanshw/p/10515436.html 谈到多隐层神经网络

回忆一下LSTM网络

输出

tensorflow 用 tf.nn.dynamic_rnn构建LSTM的输出

  1. lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_number,forget_bias = 1.0)
  2. # 初始化s
  3. init_state = lstm_cell.zero_state(batch_size,dtype=tf.float32)
  4. outputs,states = tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=init_state,time_major=False)

output,states 都是隐层的输出,注意只是隐层h,还没到o,o=vh+c(c是bias,不同于states的c,这是记忆单元)

output 是以三维矩阵形式记录了 所有样本所有时刻所有隐层的输出,shape 为 [batch_size, timestep_size, hidden_size]

states 是所有样本最后时刻所有隐层的 c 和 h,c 是记忆单元, states的shape 为 [2, batch_size, hidden_size] ,2表示 c 和 h

 states[1] == outputs[:,-1,:] == h

图形表示如下

多隐层

1. MultiRNNCell 构建多隐层LSTM,输出同 tf.nn,dynamic_rnn

2. 多隐层 h0 的shape

3. 多隐层 的输出

  1. # encoding:utf-8
  2. __author__ = 'HP'
  3. import tensorflow as tf
  4.  
  5. # 时序为1
  6.  
  7. batch_size=10
  8. depth=128 # 特征数
  9.  
  10. inputs=tf.Variable(tf.random_normal([batch_size,depth])) #
  11.  
  12. # 多隐层的h0
  13. previous_state0=(tf.random_normal([batch_size,100]),tf.random_normal([batch_size,100]))
  14. previous_state1=(tf.random_normal([batch_size,200]),tf.random_normal([batch_size,200]))
  15. previous_state2=(tf.random_normal([batch_size,300]),tf.random_normal([batch_size,300]))
  16.  
  17. num_units=[100,200,300] # 隐层神经元个数
  18. print(inputs)
  19.  
  20. cells=[tf.nn.rnn_cell.BasicLSTMCell(num_unit) for num_unit in num_units]
  21. mul_cells=tf.nn.rnn_cell.MultiRNNCell(cells)
  22.  
  23. # MultiRNNCell 直接输出
  24. outputs,states=mul_cells(inputs,(previous_state0,previous_state1,previous_state2))
  25.  
  26. print(outputs.shape) #(10, 300)
  27. print(states[0]) #第一层LSTM
  28. print(states[1]) #第二层LSTM
  29. print(states[2]) ##第三层LSTM
  30. print(states[0].h.shape) #第一层LSTM的h状态,(10, 100)
  31. print(states[0].c.shape) #第一层LSTM的c状态,(10, 100)
  32. print(states[1].h.shape) #第二层LSTM的h状态,(10, 200)
  1. 网络构建
  1. lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)
  2. mlstm_cell = rnn.MultiRNNCell([clstm() for i in range(layer_num)], state_is_tuple=True)
  3. outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
  4.  
  5. vs
  6.  
  7. mul_cells=tf.nn.rnn_cell.MultiRNNCell(cells)
  8. outputs,states=mul_cells(inputs,(previous_state0,previous_state1,previous_state2))

h0 shape

之前讲到h0的shape是 [batch_size, hidden_size],只是针对单隐层的

多隐层应该是 [batch_size, hidden1_size]  + [batch_size, hidden2_size] + [batch_size, hidden3_size] + ...

上例中用 MultiRNNCell 构建LSTM, 需要同时定义 c0 和  h0,且二者shape相同,故

  1. previous_state0=(tf.random_normal([batch_size,100]),tf.random_normal([batch_size,100]))
  2. previous_state1=(tf.random_normal([batch_size,200]),tf.random_normal([batch_size,200]))
  3. previous_state2=(tf.random_normal([batch_size,300]),tf.random_normal([batch_size,300]))

图形表示

多隐层输出

单隐层输出本文最开头已经讲了,多隐层会稍有不同

output 仍然是 [batch_size, timestep_size, hidden_size]

而 states 是 [n_layer, 2, batch_size, hidden_size]

  1. print(outputs.shape) #(10, 300)
  2. print(states[0]) #第一层LSTM
  3. print(states[1]) #第二层LSTM
  4. print(states[2]) ##第三层LSTM
  5. print(states[0].h.shape) #第一层LSTM的h状态,(10, 100)
  6. print(states[0].c.shape) #第一层LSTM的c状态,(10, 100)
  7. print(states[1].h.shape) #第二层LSTM的h状态,(10, 200)

图形表示

故 state[-1][1] == outputs[:, -1, :] == h

  1.  

tensorflow-LSTM-网络输出与多隐层节点的更多相关文章

  1. 吴裕雄--天生自然神经网络与深度学习实战Python+Keras+TensorFlow:使用TensorFlow和Keras开发高级自然语言处理系统——LSTM网络原理以及使用LSTM实现人机问答系统

    !mkdir '/content/gdrive/My Drive/conversation' ''' 将文本句子分解成单词,并构建词库 ''' path = '/content/gdrive/My D ...

  2. RNN,LSTM中如何使用TimeDistributed包装层,代码示例

    本文介绍了LSTM网络中的TimeDistributed包装层,代码演示了具有TimeDistributed层的LSTM网络配置方法. 演示了一对一,多对一,多对多,三种不同的预测方法如何配置. 在对 ...

  3. 使用tensorflow的lstm网络进行时间序列预测

    https://blog.csdn.net/flying_sfeng/article/details/78852816 版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog. ...

  4. Tensorflow进行POS词性标注NER实体识别 - 构建LSTM网络进行序列化标注

    http://blog.csdn.net/rockingdingo/article/details/55653279  Github下载完整代码 https://github.com/rockingd ...

  5. 神经网络结构设计指导原则——输入层:神经元个数=feature维度 输出层:神经元个数=分类类别数,默认只用一个隐层 如果用多个隐层,则每个隐层的神经元数目都一样

    神经网络结构设计指导原则 原文   http://blog.csdn.net/ybdesire/article/details/52821185   下面这个神经网络结构设计指导原则是Andrew N ...

  6. 循环神经网络与LSTM网络

    循环神经网络与LSTM网络 循环神经网络RNN 循环神经网络广泛地应用在序列数据上面,如自然语言,语音和其他的序列数据上.序列数据是有很强的次序关系,比如自然语言.通过深度学习关于序列数据的算法要比两 ...

  7. LSTM网络(Long Short-Term Memory )

    本文基于前两篇 1. 多层感知机及其BP算法(Multi-Layer Perceptron) 与 2. 递归神经网络(Recurrent Neural Networks,RNN) RNN 有一个致命的 ...

  8. Tensorflow[LSTM]

    0.背景 通过对<tensorflow machine learning cookbook>第9章第3节"implementing_lstm"进行阅读,发现如下形式可以 ...

  9. (译)理解 LSTM 网络 (Understanding LSTM Networks by colah)

    @翻译:huangyongye 原文链接: Understanding LSTM Networks 前言:其实之前就已经用过 LSTM 了,是在深度学习框架 keras 上直接用的,但是到现在对LST ...

随机推荐

  1. gcc优化引起get_free_page比__get_free_page返回值多4096

    2017-12-12 18:53:04 gcc优化引起get_free_page比__get_free_page返回值多4096 内核版本:1.3.100 extern inline unsigned ...

  2. centos 安装 和 linux 简单命令

    1. centos 安装 参照:https://www.cnblogs.com/tiger666/articles/10259102.html 安装过程注意点: 1. 安装过程中的选择安装Basic ...

  3. STLC - 软件测试生命周期

    什么是软件测试生命周期(STLC)? 软件测试生命周期(STLC)定义为执行软件测试的一系列活动. 它包含一系列在方法上进行的活动,以帮助认证您的软件产品. 图 - 软件测试生命周期的不同阶段 每个阶 ...

  4. 6月17 ThinkPHP连接数据库------数据的修改及删除

    1.数据修改操作 save()  实现数据修改,返回受影响的记录条数 具体有两种方式实现数据修改,与添加类似(数组.AR方式) 1.数组方式 a)         $goods = D(“Goods” ...

  5. java压缩流

    java压缩流是为了减少传输时的数据量,可以将文件压缩成ZIP.JAR.GZIP等文件格式.

  6. 不安装Oracle数据库使用plsqldevloper

    1.Oracle官网下载instantclient 解压到D:\zl\instantclient_11_2 2.配置环境变量 ORACLE_HOME = D:\zl\instantclient_11_ ...

  7. 学习Spring Security OAuth认证(一)-授权码模式

    一.环境 spring boot+spring security+idea+maven+mybatis 主要是spring security 二.依赖 <dependency> <g ...

  8. 水题系列一:Circle

    问题描述:Circle 小明在玩游戏,他正在玩一个套圈圈的游戏.他手里有 L 种固定半径的圆圈,每一种圆 圈都有其固定的数量.他要把这些圆圈套进 N 个圆形槽中的一个.这些圆形槽都有一个最 小半径和最 ...

  9. NOSQL -- Mongodb的简单操作与使用(wins)

    NOSQL -- Mongodb的简单操作与使用(wins) 启动mongodb: 1.首先启动服务 dos命令下:net start Mongndb 也可以查询服务,手动开启服务: 完成后: 2.启 ...

  10. Hadoop---hu-hadoop1: mv: cannot stat `/home/bigdata/hadoop-2.6.0/logs/hadoop-root-datanode-hu-hadoop1.out.4': No such file or directory

    hu-hadoop1: mv: cannot stat `/home/bigdata/hadoop-2.6.0/logs/hadoop-root-datanode-hu-hadoop1.out.4': ...