TF_RNNCell
参考:链接。
RNNCell
- BasicRNNCell
- GRUCell
- BasicLSTMCell
- LSTMCell
- MultiRNNCell
抽象类RNNCell
所有的rnncell均继承于RNNCell, RNNCell主要定义了几个抽象方法:
- def __call__(self, inputs, state, scope=None):
- raise NotImplementedError("Abstract method")
- @property
- def state_size(self):
- raise NotImplementedError("Abstract method")
- @property
- def output_size(self):
- raise NotImplementedError("Abstract method")
上述方法,__call__
在对象被使用时调用,其他可以看做属性方法,主要用作获取状态state的大小,cell的输出大小。既然对象使用时会调用__call__
,那么各类RNN的操作都定义在这个方法中。接下来,我们就针对各个不同的cell来详细介绍各类RNN。
BasicRNNCell
这个cell是最基础的一个RNNCell,可以看做是对一般全连接层的拓展,除了在水平方向加入时序关系,可以用下图表示:
而BasicRNNCell的初始化方法可如代码所示:
- def __init__(self, num_units, input_size=None, activation=tanh):
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- self._num_units = num_units
- self._activation = activation
初始化只需要给出num_units
,用来指有多少个隐藏层单元;而activation
指使用哪种激活函数用作激活输出。而对应的RNN操作定义在__call__
方法中:
- def __call__(self, inputs, state, scope=None):
- """Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
- with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
- output = self._activation(_linear([inputs, state], self._num_units, True))
- return output, output
很清晰,inputs
表示隐藏层的输入,state
表示上个时间的隐藏层状态,也可以说是上一次隐藏层向自身的输出,对于第一次输入,则需要初始化state,
对应初始化方法有很多种,可以使用tensorflow提供的各种初始化函数。在__call__
中,对输入inputs
和state
进行activation(wx+b),
用作下次的输入。
GRUCell
GRU是对RNN的一种改进,相比LSTM来说,也可以看做是对LSTM的一种简化,是Bengio在14年提出来的,用作机器翻译。先看一下GRU的基本结构:
这里我们结合代码来看原理:
- def __call__(self, inputs, state, scope=None):
- """Gated recurrent unit (GRU) with nunits cells."""
- with vs.variable_scope(scope or type(self).__name__): # "GRUCell"
- with vs.variable_scope("Gates"): # Reset gate and update gate.
- # We start with bias of 1.0 to not reset and not update.
- r, u = array_ops.split(1, 2, _linear([inputs, state],
- 2 * self._num_units, True, 1.0))
- r, u = sigmoid(r), sigmoid(u)
- with vs.variable_scope("Candidate"):
- c = self._activation(_linear([inputs, r * state],
- self._num_units, True))
- new_h = u * state + (1 - u) * c
- return new_h, new_h
GRUCell的初始化与RNN一样,给出输入和初始化的state,在使用对象时,利用输入和前一个时间的隐藏层状态,得到对应的Gates
: r, u, 然后利用r更新cell状态,最后利用u得到新的隐藏层状态。对于RNN的改进,最厉害的莫过于下面的,而且有很多变种,这里tensorflow中只有几个简单常见的cell。接下来,我们开始看看LSTM。
BasicLSTMCell
这个cell可以看做是最简单的LSTM,在每个连接中没有额外的连接,即其他变种在连接中加入各种改进。对于BasicLSTMCell,可以如下图所示:
同样的,我们结合代码来看它的原理:
- def __call__(self, inputs, state, scope=None):
- """Long short-term memory cell (LSTM)."""
- with vs.variable_scope(scope or type(self).__name__): # "BasicLSTMCell"
- # Parameters of gates are concatenated into one multiply for efficiency.
- if self._state_is_tuple:
- c, h = state
- else:
- c, h = array_ops.split(1, 2, state)
- concat = _linear([inputs, h], 4 * self._num_units, True)
- # i = input_gate, j = new_input, f = forget_gate, o = output_gate
- i, j, f, o = array_ops.split(1, 4, concat)
- new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
- self._activation(j))
- new_h = self._activation(new_c) * sigmoid(o)
- if self._state_is_tuple:
- new_state = LSTMStateTuple(new_c, new_h)
- else:
- new_state = array_ops.concat(1, [new_c, new_h])
- return new_h, new_state
lstm有三个门,inputs, forget, output, 而中间cell用来管理结合他们生产需要的输出。在初始化结束之后,利用输入分别得到对应的门的输出,然后利用这三个门的信息分别更新cell和当前隐藏层状态。f 用来控制遗忘之前的信息和记忆当前信息的比例,进而更新cell,lstm可以看做是一种复杂的激活函数,它的存在依赖RNN的递归性。BasicLSTMCell只是个最基本的LSTM,而完整的LSTM可能比这个复杂,可以参看blog。
MultiRNNCell
对于MultiRNNCell,只能贴出完整代码来分析了:
- class MultiRNNCell(RNNCell):
- """RNN cell composed sequentially of multiple simple cells."""
- def __init__(self, cells, state_is_tuple=False):
- """Create a RNN cell composed sequentially of a number of RNNCells.
- Args:
- cells: list of RNNCells that will be composed in this order.
- state_is_tuple: If True, accepted and returned states are n-tuples, where
- `n = len(cells)`. By default (False), the states are all
- concatenated along the column axis.
- Raises:
- ValueError: if cells is empty (not allowed), or at least one of the cells
- returns a state tuple but the flag `state_is_tuple` is `False`.
- """
- if not cells:
- raise ValueError("Must specify at least one cell for MultiRNNCell.")
- self._cells = cells
- self._state_is_tuple = state_is_tuple
- if not state_is_tuple:
- if any(nest.is_sequence(c.state_size) for c in self._cells):
- raise ValueError("Some cells return tuples of states, but the flag "
- "state_is_tuple is not set. State sizes are: %s"
- % str([c.state_size for c in self._cells]))
- @property
- def state_size(self):
- if self._state_is_tuple:
- return tuple(cell.state_size for cell in self._cells)
- else:
- return sum([cell.state_size for cell in self._cells])
- @property
- def output_size(self):
- return self._cells[-1].output_size
- def __call__(self, inputs, state, scope=None):
- """Run this multi-layer cell on inputs, starting from state."""
- with vs.variable_scope(scope or type(self).__name__): # "MultiRNNCell"
- cur_state_pos = 0
- cur_inp = inputs
- new_states = []
- for i, cell in enumerate(self._cells):
- with vs.variable_scope("Cell%d" % i):
- if self._state_is_tuple:
- if not nest.is_sequence(state):
- raise ValueError(
- "Expected state to be a tuple of length %d, but received: %s"
- % (len(self.state_size), state))
- cur_state = state[i]
- else:
- cur_state = array_ops.slice(
- state, [0, cur_state_pos], [-1, cell.state_size])
- cur_state_pos += cell.state_size
- cur_inp, new_state = cell(cur_inp, cur_state)
- new_states.append(new_state)
- new_states = (tuple(new_states) if self._state_is_tuple
- else array_ops.concat(1, new_states))
- return cur_inp, new_states
创建对象时,可以看到初始化函数中不再是输入,而是变成了cells,,即一个cell是一层,多个cell便有多层RNNcell。而在使用对象时,单层可以看做多层的特例,对于输入inputs和state,同时得到多个cell的当前隐藏层状态,用作下个时间步。看似麻烦,其实很简洁,就是加入了对多个cell的计算,最后得到的新的隐藏层状态即每个cell的上个时间步的输出。
TF_RNNCell的更多相关文章
随机推荐
- JSX AS DSL? 写个 Mock API 服务器看看
这几天打算写一个简单的 API Mock 服务器,老生常谈哈?其实我是想讲 JSX, Mock 服务器只是一个幌子. 我在寻找一种更简洁.方便.同时又可以灵活扩展的.和别人不太一样的方式,来定义各种 ...
- ASP.NET Core 入门笔记10,ASP.NET Core 中间件(Middleware)入门
一.前言 1.本教程主要内容 ASP.NET Core 中间件介绍 通过自定义 ASP.NET Core 中间件实现请求验签 2.本教程环境信息 软件/环境 说明 操作系统 Windows 10 SD ...
- 数据传输协议protobuf的使用及案例
一.交互流程图: 总结点: 问题:一开始设置http请求中content-type 设置为默认文本格式,导致使用http传输body信息的时候必须进行base64加密才可以传输,这样会导致增加传输1/ ...
- 关于js中this指向的问题
this的绑定规则有4种 默认绑定 隐性绑定 显性绑定 new绑定 this绑定优先级 new 绑定 > 显性绑定 > 隐性绑定 > 默认绑定 1.如果函数被new 修饰 this绑 ...
- PDO原生分页
** PDO分页** 1.PDO连接数据库 $dbh=new PDO('mysql:host=127.0.0.1;dbname=03a','root','root');//使用pdo 2.接收页码 $ ...
- Keepalive+双主
一.建立3台服务器之间ssh互信在mydb1,mydb2,mydb3服务器上分别执行:ssh-keygen -t rsassh-copy-id -i .ssh/id_rsa.pub root@192. ...
- T100弹出是否确认窗体方式
例如: IF NOT cl_ask_confirm('aim-00108') THEN CALL s_transaction_end(') CALL cl_err_collect_show() RET ...
- python-迭代器实现异步(在串行中)
import timedef consumer(name): print('%s 准备吃包子啦!' %name) while True: baozi = yield #yield不但可以返回值还可以接 ...
- Nmap 常用命令语法
Nmap是一个网络连接端扫描软件,用来扫描网上电脑开放的网络连接端,确定哪些服务运行在哪些连接端,并且推断计算机运行哪个操作系统,正如大多数被用于网络安全的工具,Nmap也是不少黑客及骇客爱用的工具, ...
- 啥叫K8s?啥是k8s?
•Kubernetes介绍 1.背景介绍 云计算飞速发展 - IaaS - PaaS - SaaS Docker技术突飞猛进 - 一次构建,到处运行 - 容器的快速轻量 - 完整的生态环境 2.什么是 ...