参考:链接

RNNCell

  • BasicRNNCell
  • GRUCell
  • BasicLSTMCell
  • LSTMCell
  • MultiRNNCell

抽象类RNNCell

所有的rnncell均继承于RNNCell, RNNCell主要定义了几个抽象方法:

  1. def __call__(self, inputs, state, scope=None):
  2. raise NotImplementedError("Abstract method")
  3.  
  4. @property
  5. def state_size(self):
  6. raise NotImplementedError("Abstract method")
  7.  
  8. @property
  9. def output_size(self):
  10. raise NotImplementedError("Abstract method")

上述方法,__call__在对象被使用时调用,其他可以看做属性方法,主要用作获取状态state的大小,cell的输出大小。既然对象使用时会调用__call__,那么各类RNN的操作都定义在这个方法中。接下来,我们就针对各个不同的cell来详细介绍各类RNN。

BasicRNNCell

这个cell是最基础的一个RNNCell,可以看做是对一般全连接层的拓展,除了在水平方向加入时序关系,可以用下图表示:

而BasicRNNCell的初始化方法可如代码所示:

  1. def __init__(self, num_units, input_size=None, activation=tanh):
  2. if input_size is not None:
  3. logging.warn("%s: The input_size parameter is deprecated.", self)
  4. self._num_units = num_units
  5. self._activation = activation

初始化只需要给出num_units,用来指有多少个隐藏层单元;而activation指使用哪种激活函数用作激活输出。而对应的RNN操作定义在__call__方法中:

  1. def __call__(self, inputs, state, scope=None):
  2. """Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
  3. with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
  4. output = self._activation(_linear([inputs, state], self._num_units, True))
  5. return output, output

很清晰,inputs表示隐藏层的输入,state表示上个时间的隐藏层状态,也可以说是上一次隐藏层向自身的输出,对于第一次输入,则需要初始化state,对应初始化方法有很多种,可以使用tensorflow提供的各种初始化函数。在__call__中,对输入inputsstate进行activation(wx+b),用作下次的输入。

GRUCell

GRU是对RNN的一种改进,相比LSTM来说,也可以看做是对LSTM的一种简化,是Bengio在14年提出来的,用作机器翻译。先看一下GRU的基本结构:

这里我们结合代码来看原理:

  1. def __call__(self, inputs, state, scope=None):
  2. """Gated recurrent unit (GRU) with nunits cells."""
  3. with vs.variable_scope(scope or type(self).__name__): # "GRUCell"
  4. with vs.variable_scope("Gates"): # Reset gate and update gate.
  5. # We start with bias of 1.0 to not reset and not update.
  6. r, u = array_ops.split(1, 2, _linear([inputs, state],
  7. 2 * self._num_units, True, 1.0))
  8. r, u = sigmoid(r), sigmoid(u)
  9. with vs.variable_scope("Candidate"):
  10. c = self._activation(_linear([inputs, r * state],
  11. self._num_units, True))
  12. new_h = u * state + (1 - u) * c
  13. return new_h, new_h

GRUCell的初始化与RNN一样,给出输入和初始化的state,在使用对象时,利用输入和前一个时间的隐藏层状态,得到对应的Gates: r, u, 然后利用r更新cell状态,最后利用u得到新的隐藏层状态。对于RNN的改进,最厉害的莫过于下面的,而且有很多变种,这里tensorflow中只有几个简单常见的cell。接下来,我们开始看看LSTM。

BasicLSTMCell

这个cell可以看做是最简单的LSTM,在每个连接中没有额外的连接,即其他变种在连接中加入各种改进。对于BasicLSTMCell,可以如下图所示:

同样的,我们结合代码来看它的原理:

  1. def __call__(self, inputs, state, scope=None):
  2. """Long short-term memory cell (LSTM)."""
  3. with vs.variable_scope(scope or type(self).__name__): # "BasicLSTMCell"
  4. # Parameters of gates are concatenated into one multiply for efficiency.
  5. if self._state_is_tuple:
  6. c, h = state
  7. else:
  8. c, h = array_ops.split(1, 2, state)
  9. concat = _linear([inputs, h], 4 * self._num_units, True)
  10.  
  11. # i = input_gate, j = new_input, f = forget_gate, o = output_gate
  12. i, j, f, o = array_ops.split(1, 4, concat)
  13.  
  14. new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
  15. self._activation(j))
  16. new_h = self._activation(new_c) * sigmoid(o)
  17.  
  18. if self._state_is_tuple:
  19. new_state = LSTMStateTuple(new_c, new_h)
  20. else:
  21. new_state = array_ops.concat(1, [new_c, new_h])
  22. return new_h, new_state

lstm有三个门,inputs, forget, output, 而中间cell用来管理结合他们生产需要的输出。在初始化结束之后,利用输入分别得到对应的门的输出,然后利用这三个门的信息分别更新cell和当前隐藏层状态。f 用来控制遗忘之前的信息和记忆当前信息的比例,进而更新cell,lstm可以看做是一种复杂的激活函数,它的存在依赖RNN的递归性。BasicLSTMCell只是个最基本的LSTM,而完整的LSTM可能比这个复杂,可以参看blog

MultiRNNCell

对于MultiRNNCell,只能贴出完整代码来分析了:

  1. class MultiRNNCell(RNNCell):
  2. """RNN cell composed sequentially of multiple simple cells."""
  3.  
  4. def __init__(self, cells, state_is_tuple=False):
  5. """Create a RNN cell composed sequentially of a number of RNNCells.
  6.  
  7. Args:
  8. cells: list of RNNCells that will be composed in this order.
  9. state_is_tuple: If True, accepted and returned states are n-tuples, where
  10. `n = len(cells)`. By default (False), the states are all
  11. concatenated along the column axis.
  12.  
  13. Raises:
  14. ValueError: if cells is empty (not allowed), or at least one of the cells
  15. returns a state tuple but the flag `state_is_tuple` is `False`.
  16. """
  17. if not cells:
  18. raise ValueError("Must specify at least one cell for MultiRNNCell.")
  19. self._cells = cells
  20. self._state_is_tuple = state_is_tuple
  21. if not state_is_tuple:
  22. if any(nest.is_sequence(c.state_size) for c in self._cells):
  23. raise ValueError("Some cells return tuples of states, but the flag "
  24. "state_is_tuple is not set. State sizes are: %s"
  25. % str([c.state_size for c in self._cells]))
  26.  
  27. @property
  28. def state_size(self):
  29. if self._state_is_tuple:
  30. return tuple(cell.state_size for cell in self._cells)
  31. else:
  32. return sum([cell.state_size for cell in self._cells])
  33.  
  34. @property
  35. def output_size(self):
  36. return self._cells[-1].output_size
  37.  
  38. def __call__(self, inputs, state, scope=None):
  39. """Run this multi-layer cell on inputs, starting from state."""
  40. with vs.variable_scope(scope or type(self).__name__): # "MultiRNNCell"
  41. cur_state_pos = 0
  42. cur_inp = inputs
  43. new_states = []
  44. for i, cell in enumerate(self._cells):
  45. with vs.variable_scope("Cell%d" % i):
  46. if self._state_is_tuple:
  47. if not nest.is_sequence(state):
  48. raise ValueError(
  49. "Expected state to be a tuple of length %d, but received: %s"
  50. % (len(self.state_size), state))
  51. cur_state = state[i]
  52. else:
  53. cur_state = array_ops.slice(
  54. state, [0, cur_state_pos], [-1, cell.state_size])
  55. cur_state_pos += cell.state_size
  56. cur_inp, new_state = cell(cur_inp, cur_state)
  57. new_states.append(new_state)
  58. new_states = (tuple(new_states) if self._state_is_tuple
  59. else array_ops.concat(1, new_states))
  60. return cur_inp, new_states

创建对象时,可以看到初始化函数中不再是输入,而是变成了cells,,即一个cell是一层,多个cell便有多层RNNcell。而在使用对象时,单层可以看做多层的特例,对于输入inputs和state,同时得到多个cell的当前隐藏层状态,用作下个时间步。看似麻烦,其实很简洁,就是加入了对多个cell的计算,最后得到的新的隐藏层状态即每个cell的上个时间步的输出。

TF_RNNCell的更多相关文章

随机推荐

  1. JSX AS DSL? 写个 Mock API 服务器看看

    这几天打算写一个简单的 API Mock 服务器,老生常谈哈?其实我是想讲 JSX, Mock 服务器只是一个幌子. 我在寻找一种更简洁.方便.同时又可以灵活扩展的.和别人不太一样的方式,来定义各种 ...

  2. ASP.NET Core 入门笔记10,ASP.NET Core 中间件(Middleware)入门

    一.前言 1.本教程主要内容 ASP.NET Core 中间件介绍 通过自定义 ASP.NET Core 中间件实现请求验签 2.本教程环境信息 软件/环境 说明 操作系统 Windows 10 SD ...

  3. 数据传输协议protobuf的使用及案例

    一.交互流程图: 总结点: 问题:一开始设置http请求中content-type 设置为默认文本格式,导致使用http传输body信息的时候必须进行base64加密才可以传输,这样会导致增加传输1/ ...

  4. 关于js中this指向的问题

    this的绑定规则有4种 默认绑定 隐性绑定 显性绑定 new绑定 this绑定优先级 new 绑定 > 显性绑定 > 隐性绑定 > 默认绑定 1.如果函数被new 修饰 this绑 ...

  5. PDO原生分页

    ** PDO分页** 1.PDO连接数据库 $dbh=new PDO('mysql:host=127.0.0.1;dbname=03a','root','root');//使用pdo 2.接收页码 $ ...

  6. Keepalive+双主

    一.建立3台服务器之间ssh互信在mydb1,mydb2,mydb3服务器上分别执行:ssh-keygen -t rsassh-copy-id -i .ssh/id_rsa.pub root@192. ...

  7. T100弹出是否确认窗体方式

    例如: IF NOT cl_ask_confirm('aim-00108') THEN CALL s_transaction_end(') CALL cl_err_collect_show() RET ...

  8. python-迭代器实现异步(在串行中)

    import timedef consumer(name): print('%s 准备吃包子啦!' %name) while True: baozi = yield #yield不但可以返回值还可以接 ...

  9. Nmap 常用命令语法

    Nmap是一个网络连接端扫描软件,用来扫描网上电脑开放的网络连接端,确定哪些服务运行在哪些连接端,并且推断计算机运行哪个操作系统,正如大多数被用于网络安全的工具,Nmap也是不少黑客及骇客爱用的工具, ...

  10. 啥叫K8s?啥是k8s?

    •Kubernetes介绍 1.背景介绍 云计算飞速发展 - IaaS - PaaS - SaaS Docker技术突飞猛进 - 一次构建,到处运行 - 容器的快速轻量 - 完整的生态环境 2.什么是 ...