原文地址:

https://blog.csdn.net/qq_23981335/article/details/89097757

---------------------
作者:周卫林
来源:CSDN

-----------------------------------------------------------------------------------------------

1.构建LSTM
在tensorflow中,存在两个库函数可以构建LSTM,分别为tf.nn.rnn_cell.BasicLSTMCell和tf.contrib.rnn.BasicLSTMCell,最常使用的参数是num_units,表示的是LSTM中隐含状态的维度,state_in_tuple表示将(c,h)表示为一个元组。

lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size)

2.初始化隐含状态 
LSTM的输入不仅有数据输入,还有前一个时刻的状态输入,因此需要初始化输入状态

initial_state=lstm_cell.zero_state(batch_size,dtype=tf.float32)

3.添加dropout层 
可以在基本的LSTM上添加dropout层

lstm_cell =  tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.keep_prob)

4.多层LSTM

cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*hidden_layer_num)

其中hidden_layer_num为LSTM的层数

5.完整代码

(1)原理表达最清楚、最一目了然的LSTM构建方式如下:

import tensorflow as tf
import numpy as np batch_size=2
hidden_size=64
num_steps=10
input_dim=8 input=np.random.randn(batch_size,num_steps,input_dim)
input[1,6:]=0
x=tf.placeholder(dtype=tf.float32,shape=[batch_size,num_steps,input_dim],name='input_x')
lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size)
initial_state=lstm_cell.zero_state(batch_size,dtype=tf.float32) outputs=[]
with tf.variable_scope('RNN'):
for i in range(num_steps):
if i > 0 :
# print(tf.get_variable_scope())
tf.get_variable_scope().reuse_variables() output=lstm_cell(x[:,i,:],initial_state)
outputs.append(output) with tf.Session() as sess:
init_op=tf.initialize_all_variables()
sess.run(init_op) np.set_printoptions(threshold=np.NAN) result=sess.run(outputs,feed_dict={x:input})
print(result)

(2)简化构建形式

如果觉得写for循环比较麻烦,则可以使用tf.nn.static_rnn函数,这个函数就是使用for循环实现的LSTM ,但是需要注意的是该函数的参数设置:

tf.nn.static_rnn(
cell,
inputs,
initial_state=None,
dtype=None,
sequence_length=None,
scope=None
)

其中cell即为LSTM,inputs的维度必须为  [ num_steps,  batch_size,  input_dim ]  ,sequence_length为batch_size个输入的长度。

完整代码如下:

import tensorflow as tf
import numpy as np batch_size=2
num_units=64
num_steps=10
input_dim=8 input=np.random.randn(batch_size,num_steps,input_dim)
input[1,6:]=0
x=tf.placeholder(dtype=tf.float32,shape=[batch_size,num_steps,input_dim],name='input_x')
lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(num_units)
initial_state=lstm_cell.zero_state(batch_size,dtype=tf.float32)
y=tf.unstack(x,axis=1)
# x:[batch_size,num_steps,input_dim],type:placeholder
# y:[num_steps,batch_size,input_dim],type:list
output,state=tf.nn.static_rnn(lstm_cell,y,sequence_length=[10,6],initial_state=initial_state)
with tf.Session() as sess:
init_op=tf.initialize_all_variables()
sess.run(init_op) np.set_printoptions(threshold=np.NAN) result1,result2=(sess.run([output,state],feed_dict={x:input}))
result1=np.asarray(result1)
result2=np.asarray(result2)
print(result1)
print('*'*100)
print(result2)

还可以使用tf.nn.dynamic_rnn函数来实现

tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)

该函数的cell即为LSTM,inputs的维度是    [batch_size,num_steps,input_dim]

output,state=tf.nn.dynamic_rnn(cell,x,sequence_length=[10,6],initial_state=initial_state)

6、static_rnn与dynamic_rnn之间的区别
        不论dynamic_rnn还是static_rnn,每个batch的序列长度都是一样的(不足的话自己要去padding),不同的是dynamic会根据 sequence_length 中止计算。另外一个不同是dynamic_rnn动态生成graph 。
但是dynamic_rnn不同的batch序列长度可以不一样,例如第一个batch长度为10,第二个batch长度为20,但是static_rnn不同的batch序列长度必须是相同的,都必须是num_steps

下面使用dynamic_rnn来实现不同batch之间的序列长度不同:

import tensorflow as tf
import numpy as np batch_size=2
num_units=64
num_steps=10
input_dim=8 input=np.random.randn(batch_size,num_steps,input_dim)
input2=np.random.randn(batch_size,num_steps*2,input_dim) x=tf.placeholder(dtype=tf.float32,shape=[batch_size,None,input_dim],name='input') # None 表示序列长度不定
lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(num_units)
initial_state=lstm_cell.zero_state(batch_size,dtype=tf.float32) output,state=tf.nn.dynamic_rnn(lstm_cell,x,initial_state=initial_state) with tf.Session() as sess:
init_op=tf.initialize_all_variables()
sess.run(init_op) np.set_printoptions(threshold=np.NAN) result1,result2=(sess.run([output,state],feed_dict={x:input})) # 序列长度为10 x:[batch_size,num_steps,input_dim],此时LSTM个数为10个,或者说循环10次LSTM
result1=np.asarray(result1)
result2=np.asarray(result2)
print(result1)
print('*'*100)
print(result2) result1, result2 = (sess.run([output, state], feed_dict={x:input2})) # 序列长度为20 x:[batch_size,num_steps,input_dim],此时LSTM个数为20个,或者说循环20次LSTM
result1 = np.asarray(result1)
result2 = np.asarray(result2)
print(result1)
print('*' * 100)
print(result2)

但是static_rnn是不可以的。

7.dynamic_rnn的性能和static_rnn的性能差异

import tensorflow as tf
import numpy as np
import time num_step=100
input_dim=8
batch_size=2
num_unit=64 input_data=np.random.randn(batch_size,num_step,input_dim)
x=tf.placeholder(dtype=tf.float32,shape=[batch_size,num_step,input_dim])
seq_len=tf.placeholder(dtype=tf.int32,shape=[batch_size])
lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(num_unit)
initial_state=lstm_cell.zero_state(batch_size,dtype=tf.float32) y=tf.unstack(x,axis=1)
output1,state1=tf.nn.static_rnn(lstm_cell,y,sequence_length=seq_len,initial_state=initial_state) output2,state2=tf.nn.dynamic_rnn(lstm_cell,x,sequence_length=seq_len,initial_state=initial_state) print('begin train...')
with tf.Session() as sess:
init_op=tf.initialize_all_variables()
sess.run(init_op) for i in range(100):
sess.run([output1,state1],feed_dict={x:input_data,seq_len:[10]*batch_size}) time1=time.time()
for i in range(100):
sess.run([output1,state1],feed_dict={x:input_data,seq_len:[10]*batch_size})
time2=time.time()
print('static_rnn seq_len:10\t\t{}'.format(time2-time1)) for i in range(100):
sess.run([output1,state1],feed_dict={x:input_data,seq_len:[100]*batch_size})
time3=time.time()
print('static_rnn seq_len:100\t\t{}'.format(time3-time2)) for i in range(100):
sess.run([output2,state2],feed_dict={x:input_data,seq_len:[10]*batch_size})
time4=time.time()
print('dynamic_rnn seq_len:10\t\t{}'.format(time4-time3)) for i in range(100):
sess.run([output2,state2],feed_dict={x:input_data,seq_len:[100]*batch_size})
time5=time.time()
print('dynamic_rnn seq_len:100\t\t{}'.format(time5-time4))

result:

static_rnn seq_len:10       0.8497538566589355
static_rnn seq_len:100 1.5897266864776611
dynamic_rnn seq_len:10 0.4857025146484375
dynamic_rnn seq_len:100 2.8693313598632812

序列短的要比序列长的运行的快,dynamic_rnn比static_rnn快的原因是:dynamic_rnn运行到序列长度后自动停止,不再运行,而static_rnn必须运行完num_steps才停止;序列长度为100的实验结果和分析相反,可能是因为循环耗时间,比不上直接在100个LSTM上运行的性能。

-----------------------------------------------------------------------------------------------

【转载】 LSTM构建步骤以及static_rnn与dynamic_rnn之间的区别的更多相关文章

  1. Java 中访问数据库的步骤?Statement 和PreparedStatement 之间的区别?

    Java 中访问数据库的步骤?Statement 和PreparedStatement 之间的区别? Java 中访问数据库的步骤 1)注册驱动: 2)建立连接: 3)创建Statement: 4)执 ...

  2. 【转载】 【TensorFlow】static_rnn 和dynamic_rnn的区别

    原文地址: https://blog.csdn.net/qq_20135597/article/details/88980975 ----------------------------------- ...

  3. Google分布式构建软件之三:分布式执行构建步骤

    注:本文英文原文在google开发者工具组的博客上[需要FQ],以下是我的翻译,欢迎转载,但请尊重作者版权,注名原文地址. 之前两篇文章分别介绍了Google 分布式软件构建系统Blaze相关的为了提 ...

  4. 在TensorFlow中基于lstm构建分词系统笔记

    在TensorFlow中基于lstm构建分词系统笔记(一) https://www.jianshu.com/p/ccb805b9f014 前言 我打算基于lstm构建一个分词系统,通过这个例子来学习下 ...

  5. TeamCity 创建jar构建步骤

    1 创建工程 2 配置工程代码来源信息 2.1 From a repository URL 表示从代码仓库创建工程. 2.1.1 parent project 指定父工程,默认是root projec ...

  6. 自定义Qt构建步骤,添加数据文件(txt,json等)到构建目录

    Qt的qrc资源文件是只读的,因此我们如果要用txt之类的文件存储数据,在程序运行过程中就不能对它们进行修改,也就是不能进行读操作.用"file.open(QIODevice::WriteO ...

  7. jenkins检查代码,如没更新停止构建步骤

    需求分析 在jenkins中没有找到构建前插件,每次构建时间很长,希望可以实现判断代码是否更新,如果没更细则停止构建步骤. 实现步骤 在构建时执行shell命令,而jenkins提供的的环境变量可以实 ...

  8. SpringCloud学习笔记(四):Eureka服务注册与发现、构建步骤、集群配置、Eureka与Zookeeper的比较

    简介 Netflix在设计Eureka时遵守的就是AP原则 拓展: 在分布式数据库中的CAP原理 CAP原则又称CAP定理,指的是在一个分布式系统中,Consistency(一致性). Availab ...

  9. 转载:详细解析Java中抽象类和接口的区别

    在Java语言中, abstract class 和interface 是支持抽象类定义的两种机制.正是由于这两种机制的存在,才赋予了Java强大的 面向对象能力.abstract class和int ...

随机推荐

  1. Web开发之跨域问题

    最近在工作上遇到了跨域方面的问题,借此温习巩固. 跨域是受到浏览器的同源策略引起的,为了防止某些文档或脚本加载别的域下的未知内容造成泄露隐私,破坏系统等安全行为. 那什么是同源的呢? 同源是指:应用协 ...

  2. Spring 重定向(Redirect)指南

    原文:Hacking the IntegerCache in Java 9? 链接:https://dzone.com/articles/hacking-the-integercache-in-jav ...

  3. 25.centos7基础学习与积累-011-课前考试二-命令练习

    从头开始积累centos7系统运用 大牛博客:https://blog.51cto.com/yangrong/p5 取IP地址: 6的命令:ifconfig eth0 7的命令 [root@pytho ...

  4. C 是什么样的语言?

    学习交流可加 微信读者交流①群 (添加微信:coderAllen) 程序员技术QQ交流①群:736386324 --- ==C 是什么样的语言?== 这个问题不要急于寻找问题的答案,而是应该先去考虑当 ...

  5. 关于ssh_config和sshd_config

    转载:https://www.cnblogs.com/panda2046/p/5933498.html   在远程管理linux系统基本上都要使用到ssh,原因很简单:telnet.FTP等传输方式是 ...

  6. Mysql InnoDB行锁不使用索引锁表的时候会锁整张表

    原文:http://www.thinkphp.cn/topic/41577.html 如果使用针对InnoDB的表使用行锁,被锁定字段不是主键,也没有针对它建立索引的话.行锁锁定的也是整张表.锁整张表 ...

  7. ELK日志分析系统搭建 windows

    1 分别下载elk包 下载地址 https://www.elastic.co/cn/downloads 2 将这三个解压到同一个目录下,便于管理 3 elasticsearch不需要修改配置 默认即可 ...

  8. Pollard-rho的质因数分解

    思路:见参考文章(原理我是写不粗来了) 代码: 用到了快速幂,米勒罗宾素性检验. #include <iostream> #include <time.h> #include ...

  9. Serializable的作用

    前两天接触到VO,DTO,entity这些概念,发现别人的代码中会有 implements serializable这个东西,之前并没有见过这种写法,就去了解了一下原因 import java.io. ...

  10. centos7部署postgresql集群高可用 patroni + etcd 之patroni篇

    实验环境:centos7.4纯净版 postgres版本: 9.6.15 etcd版本:3.3.11 patroni版本:1.6.0 patroni介绍可参考:https://github.com/z ...