

  • BasicRNNCell
  • GRUCell
  • BasicLSTMCell
  • LSTMCell
  • MultiRNNCell


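All RNN cells inherit from RNNCell, which mainly defines three abstract members, the __call__ method and the state_size and output_size properties: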

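    def __call__(self, inputs, state, scope=None):
        # Run one time step and return (output, new_state).
        raise NotImplementedError("Abstract method")

    @property
    def state_size(self):
        # Size of the state passed between time steps (an int or a tuple of sizes).
        raise NotImplementedError("Abstract method")

    @property
    def output_size(self):
        # Size of the output produced at each time step.
        raise NotImplementedError("Abstract method")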
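To see what these three abstract members promise, here is a plain-Python sketch of the same contract. It is not TensorFlow code: it uses the standard abc module instead of raising NotImplementedError, and the names ToyRNNCell, RunningSumCell and run are hypothetical. A cell reports the sizes of its state and output, and __call__ maps (inputs, state) to (output, new_state), so a simple loop can thread the state through a sequence.

```python
from abc import ABC, abstractmethod


class ToyRNNCell(ABC):
    """Hypothetical stand-in for RNNCell with the same three abstract members."""

    @abstractmethod
    def __call__(self, inputs, state):
        """Return (output, new_state) for one time step."""

    @property
    @abstractmethod
    def state_size(self):
        """Size of the state carried between time steps."""

    @property
    @abstractmethod
    def output_size(self):
        """Size of the output produced at each time step."""


class RunningSumCell(ToyRNNCell):
    """Hypothetical cell whose state is the running sum of its inputs."""

    def __init__(self, num_units):
        self._num_units = num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state):
        new_state = [x + s for x, s in zip(inputs, state)]
        return new_state, new_state  # output and new state coincide here


def run(cell, sequence, state):
    """Unroll a cell over a list of inputs, threading the state through."""
    outputs = []
    for x in sequence:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state


cell = RunningSumCell(2)
outs, final = run(cell, [[1, 2], [3, 4]], [0, 0])
print(outs, final)  # [[1, 2], [4, 6]] [4, 6]
```

Like RNNCell itself, the abstract base cannot be used directly: ToyRNNCell() raises TypeError because its abstract members are not implemented.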
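
BasicRNNCell is the simplest cell. Its constructor only stores num_units and the activation function (tanh by default; input_size is deprecated and just logs a warning), and __call__ computes one step:
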
    def __init__(self, num_units, input_size=None, activation=tanh):
        if input_size is not None:
            logging.warn("%s: The input_size parameter is deprecated.", self)
        self._num_units = num_units
        self._activation = activation

    def __call__(self, inputs, state, scope=None):
        """Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
        with vs.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
            # _linear concatenates [inputs, state] and applies one weight matrix plus a bias.
            output = self._activation(_linear([inputs, state], self._num_units, True))
        return output, output
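The docstring's W * input + U * state becomes a single multiply because _linear works on the concatenation [inputs, state]. The sketch below assumes exactly that behavior (concatenate the arguments, multiply by one weight matrix, add a bias when bias=True) and uses plain Python lists for a single example rather than a batch; linear and basic_rnn_step are hypothetical names, not TensorFlow functions. The first row of weights plays the role of W and the second the role of U.

```python
import math


def linear(args, weights, bias):
    """Hypothetical stand-in for _linear on a single example.

    Concatenates the vectors in args, multiplies by weights
    (rows = total input size, columns = num_units) and adds bias.
    """
    x = [v for arg in args for v in arg]  # [inputs, state] -> one vector
    return [sum(x[r] * weights[r][k] for r in range(len(x))) + bias[k]
            for k in range(len(bias))]


def basic_rnn_step(inputs, state, weights, bias, activation=math.tanh):
    """Hypothetical sketch: output = new_state = activation(W*inputs + U*state + B)."""
    output = [activation(z) for z in linear([inputs, state], weights, bias)]
    return output, output


# One input feature and one unit: weights[0] acts as W, weights[1] as U.
weights = [[0.5], [2.0]]
bias = [0.0]
state = [0.0]
for x in ([1.0], [0.0]):
    output, state = basic_rnn_step(x, state, weights, bias)
```

After the second step the state is tanh(2.0 * tanh(0.5)): the input is zero, so only the U path carries the previous state forward.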

GRUCell's __call__ adds a reset gate and an update gate on top of the same _linear building block:

    def __call__(self, inputs, state, scope=None):
        """Gated recurrent unit (GRU) with nunits cells."""
        with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
            with vs.variable_scope("Gates"):  # Reset gate and update gate.
                # We start with bias of 1.0 to not reset and not update.
                # Pre-1.0 signature: array_ops.split(split_dim, num_split, value).
                r, u = array_ops.split(1, 2, _linear([inputs, state],
                                                     2 * self._num_units, True, 1.0))
                r, u = sigmoid(r), sigmoid(u)
            with vs.variable_scope("Candidate"):
                c = self._activation(_linear([inputs, r * state],
                                             self._num_units, True))
            new_h = u * state + (1 - u) * c
        return new_h, new_h

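GRUCell is constructed the same way as BasicRNNCell (num_units and an activation function). When the cell is called, it uses the input and the previous hidden state to compute the two gates, the reset gate r and the update gate u. It then uses r to scale the previous state when computing the candidate state c, and finally uses u to mix the previous state and the candidate into the new hidden state: new_h = u * state + (1 - u) * c. Because both gate biases start at 1.0, the gates initially lean toward values above 0.5 (sigmoid(1.0) ≈ 0.73), so a fresh cell tends neither to reset nor to update. Among the improvements to the RNN, the most powerful is the one covered next, which also has many variants; TensorFlow itself only ships a few simple, common cells. Next, let's look at LSTM.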
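Before moving on to LSTM, here is the GRU step described above as a plain-Python sketch for a single example. The names gru_step and linear are hypothetical (linear assumes _linear concatenates its arguments and applies one weight matrix plus bias), and slicing the 2n gate outputs stands in for array_ops.split(1, 2, ...). The first call uses the gate bias of 1.0 from the source; the second uses a very large update-gate bias to show that u close to 1 keeps the old state.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def linear(args, weights, bias):
    """Hypothetical _linear stand-in: concat(args) @ weights + bias, one example."""
    x = [v for arg in args for v in arg]
    return [sum(x[r] * weights[r][k] for r in range(len(x))) + bias[k]
            for k in range(len(bias))]


def gru_step(inputs, state, gate_w, gate_b, cand_w, cand_b,
             activation=math.tanh):
    """Hypothetical sketch of GRUCell.__call__ for one example."""
    n = len(state)
    # One multiply yields 2n gate values; split them into r (reset) and u (update).
    gates = [sigmoid(z) for z in linear([inputs, state], gate_w, gate_b)]
    r, u = gates[:n], gates[n:]
    # The reset gate scales the previous state before it enters the candidate.
    reset_state = [ri * hi for ri, hi in zip(r, state)]
    c = [activation(z) for z in linear([inputs, reset_state], cand_w, cand_b)]
    # The update gate mixes the previous state and the candidate.
    new_h = [ui * hi + (1.0 - ui) * ci for ui, hi, ci in zip(u, state, c)]
    return new_h, new_h


# One input feature, one unit. Zero gate weights leave only the gate biases.
gate_w = [[0.0, 0.0], [0.0, 0.0]]   # rows: [input, state]; columns: [r, u]
cand_w = [[1.0], [1.0]]             # rows: [input, r * state]
h, _ = gru_step([0.5], [0.2], gate_w, [1.0, 1.0], cand_w, [0.0])

# A very large update-gate bias (u close to 1) makes the cell keep its state.
h_kept, _ = gru_step([0.5], [0.2], gate_w, [1.0, 50.0], cand_w, [0.0])
```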

BasicLSTMCell's __call__ looks like this:

    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
        with vs.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
            # Parameters of gates are concatenated into one multiply for efficiency.
            if self._state_is_tuple:
                c, h = state
            else:
                c, h = array_ops.split(1, 2, state)
            concat = _linear([inputs, h], 4 * self._num_units, True)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = array_ops.split(1, 4, concat)

            new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                     self._activation(j))
            new_h = self._activation(new_c) * sigmoid(o)

            if self._state_is_tuple:
                new_state = LSTMStateTuple(new_c, new_h)
            else:
                new_state = array_ops.concat(1, [new_c, new_h])
            return new_h, new_state

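An LSTM has three gates, the input gate (i), the forget gate (f) and the output gate (o), plus a memory cell c that combines them to produce the output. In __call__, a single _linear call computes all four pre-activations i, j, f, o from the input and the previous h, and the gates are then used to update the cell state and the current hidden state: new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * activation(j), and new_h = activation(new_c) * sigmoid(o). The forget gate f controls how much of the previous cell state is kept, while the input gate i controls how much of the new candidate j is written; together they update the cell. An LSTM can be seen as an elaborate activation function whose usefulness depends on the recurrence of the RNN. BasicLSTMCell is only the most basic LSTM; a full LSTM can be more complex than this. See the referenced blog post for details.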
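The BasicLSTMCell update can be written out the same way. This plain-Python sketch handles one example and uses hypothetical names (basic_lstm_step, linear, and StateTuple, which mirrors the (c, h) order of LSTMStateTuple); forget_bias=1.0 is assumed here as the default. With all weights and biases set to zero, every gate pre-activation is 0, so the example isolates how forget_bias alone decides how much of the old cell state survives.

```python
import math
from collections import namedtuple

# Hypothetical stand-in for LSTMStateTuple: the state is (c, h), in that order.
StateTuple = namedtuple("StateTuple", ["c", "h"])


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def linear(args, weights, bias):
    """Hypothetical _linear stand-in: concat(args) @ weights + bias, one example."""
    x = [v for arg in args for v in arg]
    return [sum(x[r] * weights[r][k] for r in range(len(x))) + bias[k]
            for k in range(len(bias))]


def basic_lstm_step(inputs, state, weights, bias, forget_bias=1.0,
                    activation=math.tanh):
    """Hypothetical sketch of BasicLSTMCell.__call__ with state_is_tuple=True."""
    c, h = state
    n = len(c)
    # One multiply gives 4n values, split in order into i, j, f, o.
    concat = linear([inputs, h], weights, bias)
    i, j, f, o = (concat[k * n:(k + 1) * n] for k in range(4))
    # Forget gate keeps part of the old cell; input gate writes part of j.
    new_c = [ck * sigmoid(fk + forget_bias) + sigmoid(ik) * activation(jk)
             for ck, ik, jk, fk in zip(c, i, j, f)]
    # Output gate decides how much of the squashed cell becomes h.
    new_h = [activation(nc) * sigmoid(ok) for nc, ok in zip(new_c, o)]
    return new_h, StateTuple(new_c, new_h)


# One input feature, one unit, all weights and biases zero, so every gate
# pre-activation is 0 and only forget_bias makes f differ from i and o.
weights = [[0.0] * 4, [0.0] * 4]   # rows: [input, h]; columns: [i, j, f, o]
bias = [0.0] * 4
out, new_state = basic_lstm_step([0.3], StateTuple(c=[2.0], h=[0.0]),
                                 weights, bias)
```

Here new_c is 2.0 * sigmoid(1.0), because j contributes tanh(0) = 0, and new_h is tanh(new_c) * 0.5.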
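
MultiRNNCell stacks several cells: the output of each layer is the input of the next, and the per-layer states are kept either as a tuple or concatenated along the column axis:
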
    class MultiRNNCell(RNNCell):
        """RNN cell composed sequentially of multiple simple cells."""

        def __init__(self, cells, state_is_tuple=False):
            """Create a RNN cell composed sequentially of a number of RNNCells.

            Args:
              cells: list of RNNCells that will be composed in this order.
              state_is_tuple: If True, accepted and returned states are n-tuples, where
                `n = len(cells)`. By default (False), the states are all
                concatenated along the column axis.

            Raises:
              ValueError: if cells is empty (not allowed), or at least one of the cells
                returns a state tuple but the flag `state_is_tuple` is `False`.
            """
            if not cells:
                raise ValueError("Must specify at least one cell for MultiRNNCell.")
            self._cells = cells
            self._state_is_tuple = state_is_tuple
            if not state_is_tuple:
                if any(nest.is_sequence(c.state_size) for c in self._cells):
                    raise ValueError("Some cells return tuples of states, but the flag "
                                     "state_is_tuple is not set. State sizes are: %s"
                                     % str([c.state_size for c in self._cells]))

        @property
        def state_size(self):
            if self._state_is_tuple:
                return tuple(cell.state_size for cell in self._cells)
            else:
                return sum([cell.state_size for cell in self._cells])

        @property
        def output_size(self):
            # Only the last layer's output leaves the stack.
            return self._cells[-1].output_size

        def __call__(self, inputs, state, scope=None):
            """Run this multi-layer cell on inputs, starting from state."""
            with vs.variable_scope(scope or type(self).__name__):  # "MultiRNNCell"
                cur_state_pos = 0
                cur_inp = inputs
                new_states = []
                for i, cell in enumerate(self._cells):
                    with vs.variable_scope("Cell%d" % i):
                        if self._state_is_tuple:
                            if not nest.is_sequence(state):
                                raise ValueError(
                                    "Expected state to be a tuple of length %d, but received: %s"
                                    % (len(self.state_size), state))
                            cur_state = state[i]
                        else:
                            # Take this layer's columns out of the concatenated state.
                            cur_state = array_ops.slice(
                                state, [0, cur_state_pos], [-1, cell.state_size])
                            cur_state_pos += cell.state_size
                        # The output of layer i becomes the input of layer i + 1.
                        cur_inp, new_state = cell(cur_inp, cur_state)
                        new_states.append(new_state)
            new_states = (tuple(new_states) if self._state_is_tuple
                          else array_ops.concat(1, new_states))
            return cur_inp, new_states
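The state handling in MultiRNNCell.__call__ is the subtle part. The plain-Python sketch below assumes single examples stored as lists, toy cells written as functions rather than RNNCell objects, and hypothetical names (add_cell, multi_cell_step, multi_cell_step_tuple). With state_is_tuple=False, each layer takes its state_size entries out of one flat state, in order, and the new states are concatenated back; with state_is_tuple=True, layer i simply reads state[i]. In both cases the output of one layer is the input of the next, and only the last layer's output is returned, which is why output_size comes from self._cells[-1].

```python
def add_cell(inputs, state):
    """Hypothetical toy cell: output = new_state = inputs + state (elementwise)."""
    out = [x + s for x, s in zip(inputs, state)]
    return out, out


def multi_cell_step(cells, state_sizes, inputs, state):
    """Hypothetical sketch of MultiRNNCell.__call__ with state_is_tuple=False.

    state is one flat list holding every layer's state side by side.
    """
    pos = 0
    cur_inp = inputs
    new_states = []
    for cell, size in zip(cells, state_sizes):
        cur_state = state[pos:pos + size]   # like array_ops.slice on the columns
        pos += size
        cur_inp, new_state = cell(cur_inp, cur_state)  # output feeds next layer
        new_states.extend(new_state)        # like array_ops.concat(1, ...)
    return cur_inp, new_states


def multi_cell_step_tuple(cells, inputs, state):
    """Same idea with state_is_tuple=True: state[i] belongs to layer i."""
    cur_inp = inputs
    new_states = []
    for cell, cur_state in zip(cells, state):
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)
    return cur_inp, tuple(new_states)


cells = [add_cell, add_cell]
out, flat = multi_cell_step(cells, [2, 2], [1, 1], [0, 0, 10, 10])
print(out, flat)  # [11, 11] [1, 1, 11, 11]
out_t, nested = multi_cell_step_tuple(cells, [1, 1], ([0, 0], [10, 10]))
print(out_t, nested)  # [11, 11] ([1, 1], [11, 11])
```

The flat state has length 4, matching state_size = sum of the per-layer sizes when state_is_tuple is False.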




