0x00 前言

常用的LSTM,或是双向LSTM,输出的结果通常是以下两个:
1) outputs,包括所有节点的hidden
2) 末节点的state,包括末节点的hidden和cell
大部分任务有这些就足够了,state是随着节点间信息的传递依次变化并容纳更多信息,
所以通常末状态的cell就囊括了所有信息,不需要中间每个节点的cell信息,
但如果我们的研究过程中需要用到这些cell该如何是好呢?

近期的任务中,需要每个节点的前后节点cell信息来做某种判断,
所以属于一个较为特殊的任务,自主实现了一下这个同样也会反馈cell的LSTM,
哦顺带一提Cell-Holding,是强行为了简称成CHD取的名字(笑)

0x01 分析与设计

首先分析源码,看一下通常LSTM层调用使用 dynamic_rnn 的实现逻辑,
原逻辑大概是这样的:

1
2
3
4
5
6
outputs = []
state = Cell.zero_state(N, tf.float32) # state = (hidden, cell)
for input in inputs:
trueoutput, state = Cell(input, state) # hidden, (hidden, cell) = Cell()
trueoutputs.append(output) # outputs.append(hidden)
return outputs, state # outputs := a list of (hidden)

那么其实……我们只需要重新实现一个简化的版本,让cell留下来即可。
此处使用的逻辑大概是这样的:

1
2
3
4
5
6
states_case = []
state = Cell.zero_state(N, tf.float32) # state = (hidden, cell)
for input in inputs:
trueoutput, state = Cell(input, state) # hidden, (hidden, cell) = Cell()
trueoutputs.append(output) # states_case.append((hidden, cell))
return states_case # states_case := list of (hidden, cell)

为了实现这些,就需要做到以下几件事情:
1) 获取或共享已有LSTM层的BasicLSTMCell
2) 编写Cell相关计算,保留LSTM计算途中的信息,可自定义获取输出的格式
3) 采用设计的输出格式使用这些节点信息,以完成其他任务

0x02 Source Code

Advanced LSTM Layer

[LstmLayer] in tf_layers
首先要在不影响功能的情况下改写原有的LSTM Layer,令其支持获取BasicCell的操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class LstmLayer(object):
true# based on LSTM Layer, thanks for @lhw446
def __init__(self, input_dim, num_units, sequence_length=None, bidirection=False, name="lstm"):
self.input_dim = input_dim
self.num_units = num_units
self.bidirection = bidirection
self.sequence_length = sequence_length
self.name = name
# `with ... as...` remains assignment work.
self.lstm_fw_cell = None
self.lstm_bw_cell = None
with tf.name_scope('%s_def' % (self.name)):
self.lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, state_is_tuple=True)
if self.bidirection:
self.lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, state_is_tuple=True)
def __call__(self, inputs, sequence_length=None, time_major=False,
initial_state_fw=None, initial_state_bw=None):
inputs_shape = tf.shape(inputs)
inputs = tf.reshape(inputs, [-1, inputs_shape[-2], self.input_dim])
sequence_length = self.sequence_length if sequence_length is None
else tf.reshape(sequence_length, [-1])
if initial_state_fw is not None:
initial_state_fw = tf.nn.rnn_cell.LSTMStateTuple(
tf.reshape(initial_state_fw[0], [-1, self.num_units]),
tf.reshape(initial_state_fw[1], [-1, self.num_units]))
if initial_state_bw is not None:
initial_state_bw = tf.nn.rnn_cell.LSTMStateTuple(
tf.reshape(initial_state_bw[0], [-1, self.num_units]),
tf.reshape(initial_state_bw[1], [-1, self.num_units]))
resh_1 = lambda tensors: tf.reshape(
tensors, tf.concat([inputs_shape[:-1], [tf.shape(tensors)[-1]]], 0))
resh_2 = lambda tensors: tf.reshape(
tensors, tf.concat([inputs_shape[:-2], [tf.shape(tensors)[-1]]], 0))
with tf.variable_scope('%s_cal' % (self.name)):
if self.bidirection:
outputs, output_states = tf.nn.bidirectional_dynamic_rnn(
self.lstm_fw_cell, self.lstm_bw_cell, inputs,
sequence_length=sequence_length,
initial_state_fw=initial_state_fw,
initial_state_bw=initial_state_bw,
time_major=time_major, dtype=tf.float32)
# (fw_outputs, bw_outputs)
大专栏  【TensorFlow】自主实现包含全节点Cell的LSTM层 Cell> outputs = tf.nn.rnn_cell.LSTMStateTuple(resh_1(outputs[0]), resh_1(outputs[1]))
# ((fw_c_states, fw_m_states), (bw_c_states, bw_m_states))
output_states = tf.nn.rnn_cell.LSTMStateTuple(
tf.nn.rnn_cell.LSTMStateTuple(resh_2(output_states[0][0]), resh_2(output_states[0][1])),
tf.nn.rnn_cell.LSTMStateTuple(resh_2(output_states[1][0]), resh_2(output_states[1][1])))
else:
outputs, output_states = tf.nn.dynamic_rnn(
self.lstm_fw_cell, inputs, sequence_length=sequence_length,
initial_state=initial_state_fw,
time_major=time_major, dtype=tf.float32)
outputs = resh_1(outputs) # (outputs)
# (c_states, m_states)
output_states = tf.nn.rnn_cell.LSTMStateTuple(
resh_2(output_states[0]), resh_2(output_states[1]))
return outputs, output_states

Cell-HolDing Layer

chd_lstm_layer in network
然后基于目标LSTM层,构建使用相同基本单元的scope,设定初始零状态,逐层计算
(此处仅剪枝了所有的padding位,没有特意做加速,用了简单的python-like的for循环)
(且为了本次实验需要,没有将hidden和cell区分开来,而是直接保存了state整体,可自行修改)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def chd_lstm_layer(self, inputs, target_layer):
cell = target_layer.lstm_fw_cell
with tf.variable_scope('%s_cal' % (target_layer.name)):
# generate initial states for current inputs
states_case = []
for batch_idx in range(self.batch_size):
batch_state_case = []
state = cell.zero_state(1, tf.float32)
for time_step in range(self.seg_len[batch_idx]):
tf_input = inputs[batch_idx, time_step]
output, _state = cell(
tf.reshape(tf_input, [1, -1]), state)
batch_state_case.append(_state)
state = _state
states_case.append(batch_state_case)
# a nested list of states [batch_size, seg_len]
return states_case, cell

上述是任务需要,
主要演示了可以简单的循环调用给定LSTM层的Cell进行计算,
在对齐的情况下还可以通过stack等操作拼成一个tf的矩阵使用。
其中用作循环迭代次数的参数 self.batch_size self.seg_len等,
不可以是tf.placeholder,因为range内必须为一个固定的数值而不能为一个占位符(tf.loop不知道能不能做到)
所以在feed_dict前,我做了如下的操作,将这些固定数值作为 instance_variables 传给网络以供使用。

1
2
3
4
5
6
7
8
9
10
11
12
def gen_infer_inputs(self, data):
# data = merge_by_batch_size(batch_data_generate(data))
self.batch_size = data['cell_lens'].shape[0]
self.seg_len = data['cell_lens']
self.can_len = data['candi_mask'].sum(-1)
return {
self.input_data: data['input_data'],
self.cell_lens: data['cell_lens'],
self.candidates: data['candidates'],
self.candi_mask: data['candi_mask'],
self.keep_prob: 1.0,
}

Further usage on states_case

others_layer in network
获取了states_case之后,可以用于各个位置的使用
下文中给出一个使用案例,此处用于计算相同LSTM序列中,替换其中任意节点为其他节点的输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def replace_layer(self, forward_emb, candidate_emb):
backward_emb = self.get_reverse(forward_emb, rev_length=self.cell_lens + 2)
fw_states, fw_cell = self.chd_lstm_layer(
forward_emb, self.forward_lstm)
bw_states, bw_cell = self.chd_lstm_layer(
backward_emb, self.backward_lstm)
hidden_case = []
for batch_idx in range(self.batch_size):
batch_case = []
for time_step in range(self.seg_len[batch_idx]):
time_case = []
for candidate_idx in range(self.can_len[batch_idx, time_step]):
tf_input = candidate_emb[batch_idx, time_step, candidate_idx]
fw_hidden, _ = fw_cell(
tf.reshape(tf_input, [1, -1]),
fw_states[batch_idx][time_step])
bw_hidden, _ = bw_cell(
tf.reshape(tf_input, [1, -1]),
bw_states[batch_idx][-time_step])
hidden = tf.concat([fw_hidden, bw_hidden], -1)
time_case.append(hidden)
batch_case.append(time_case)
hidden_case.append(batch_case)
return hidden_case # a nested list.

0x03 后记

cell因其持续更新且后者包含前者信息的特性通常不被保存,
但是 LSTMCell RNNCell 的调用却需要完整的state(包括hiddencell),
在我们对已经计算完毕的LSTM序列中内部的某些节点有所想法时,就很难回溯了,
所以说不定这种layer也是有一定价值的,目前tensorflow里还没有整合成类似的层,
所以自行手写了一个,虽说不是太复杂,不过提供了这样一种想法,记录一下~
(说不定以后就加了这个层呢~ 到时候我可以指着这篇文章说我早就想到咯^_^)

【TensorFlow】自主实现包含全节点Cell的LSTM层 Cell的更多相关文章

  1. 查看tensorflow pb模型文件的节点信息

    查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...

  2. jQuery 获取当前节点的html包含当前节点的方法

    在开发过程中,jQuery.html() 是获取当前节点下的html代码,并不包含当前节点本身的代码,然后我们有时候确需要,找遍jQuery api文档也没有任何方法可以拿到. 看到有的人通过pare ...

  3. jquery 获取 outerHtml 包含当前节点本身的代码

    在开发过程中,jQuery.html() 是获取当前节点下的html代码,并不包含当前节点本身的代码,然后我们有时候确需要,找遍jQuery api文档也没有任何方法可以拿到. 看到有的人通过pare ...

  4. 比特币全节点(bitcoind) eth 全节点

    运行全节点的用途:  1.挖矿  2.钱包   运行全节点,可以做关于btc的任何事情,例如创建钱包地址.管理钱包地址.发送交易.查询全网的交易信息等等 选个节点钱包:bitcoind 1.配置文件: ...

  5. 以太坊geth主网全节点部署

    以太坊geth主网全节点部署 #环境 ubuntu 16.4 #硬盘500GB(目前占用200G) #客户端安装 # 查看下载页面最新版 # https://ethereum.github.io/go ...

  6. 比特币BTC全节点搭建

    比特币BTC全节点搭建 #环境 ubuntu 16.4 #硬盘500GB #截止2018-12-31磁盘占用超过230GB #客户端安装 #下载页面 #https://bitcoin.org/zh_C ...

  7. 以太坊go-ethereum客户端(三)两种全节点启动模式

    这篇博客介绍一下go-ethereum全节点的两种启动模式:主网络快速启动和测试网络快速启动.这也是客户端所提供的两种启动方式,直接拿来使用即可.下面具体介绍一下使用方法. 主网络快速启动 其实,我们 ...

  8. 100万套PPT模板,包含全宇宙所有主题类型PPT,绕宇宙100圈,持续更新

    100万套PPT模板,包含全宇宙所有主题类型PPT(全部免费,都是精品,没有一张垃圾不好看的PPT,任何一张PPT拿来套入自己的信息就可以立马使用),绕宇宙100圈,任意一个模板在某文库上都价不菲.强 ...

  9. JS获取包含当前节点本身的代码内容(outerHtml)

    原生JS DOM的内置属性 outerHTML 可用来获取当前节点的html代码(包含当前节点),且此属性可使用jQuery的prop()获取 <div id="demo-test-0 ...

随机推荐

  1. 题解-------P4053 [JSOI2007]建筑抢修

    传送门 贪心+左偏树 贪心思路:先修快炸的楼 所以我们可以按照$T2$从大到小做一遍排序,然后从$1\cdots n$一个一个去修,如果这栋楼不能修(也就是当前时间已经超过$T2_{i}$),那我们就 ...

  2. 基于JSP+Servlet开发在线租车系统 java 源码

    运行环境: 最好是java jdk 1.8,我们在这个平台上运行的.其他版本理论上也可以.IDE环境: Eclipse,Myeclipse,IDEA都可以tomcat环境: Tomcat 7.x,8. ...

  3. hdu2896&&3065

    题:http://acm.hdu.edu.cn/showproblem.php?pid=2896 分析:ac自动机模板 注意细节,1.128个ascii码都要: 2.只要关键码含有只输出一个编号就行 ...

  4. 容斥原理的(二进制思想和质因子分解+模板)hdu4135+ecf81.D

    题:http://acm.hdu.edu.cn/showproblem.php?pid=4135 题意:求[A,B]与N互质的数的个数 #include<iostream> #includ ...

  5. 线程池-进程池-io模型

    一.线程池与进程池 什么是池?简单的说就是一个容器,一个范围 在保证计算机硬件安全的情况下最大限度的充分利用计算机, 池其实是降低了程序的运行效率,但是保证了计算机硬件的安全,也是实现了一个并发的效果 ...

  6. Mysql数据库删除重复数据

    最近因为发现数据库中的表有脏数据,需要维护.这些脏数据就是重复数据,需要将其删除. 现假设有一张test表,主键字段为num,还有id,one,two三个字段.假设id规定只能有一条记录(即需要为id ...

  7. Jupyter_Notebook

    TA-lib指标库地址 http://github.com/xingbuxing/TA-Lib-in-chinese 1.Jupter是基于网页端写代码,属于一种交互式的编程,除了在上面写代码之外还可 ...

  8. [LC] 207. Course Schedule

    There are a total of n courses you have to take, labeled from 0 to n-1. Some courses may have prereq ...

  9. 使用 Wintersmith + Serverless Framework 快速创建个人站点

    首先我们来介绍下,Wintersmith 是一个简单而灵活的静态站点生成器.采用 markdown 构建,这个是我们的基础条件. Serverless Framework:在 GitHub 上有三万颗 ...

  10. scala编程(八)——函数和闭包

    当程序变得庞大时,你需要一些方法把它们分割成更小的,更易管理的片段.为了分割控制流,Scala 提供了所有有经验的程序员都熟悉的方式:把代码分割成函数.实际上,Scala 提供了许多 Java 中没有 ...