一 RNN概述
    前面我们叙述了BP算法, CNN算法, 那么为什么还会有RNN呢?? 什么是RNN, 它到底有什么不同之处? RNN的主要应用领域有哪些呢?这些都是要讨论的问题.

1) BP算法,CNN之后, 为什么还有RNN?

细想BP算法,CNN(卷积神经网络)我们会发现, 他们的输出都是只考虑前一个输入的影响而不考虑其它时刻输入的影响, 比如简单的猫,狗,手写数字等单个物体的识别具有较好的效果. 但是, 对于一些与时间先后有关的, 比如视频的下一时刻的预测,文档前后文内容的预测等, 这些算法的表现就不尽如人意了.因此, RNN就应运而生了.

2) 什么是RNN?

RNN是一种特殊的神经网络结构, 它是根据"人的认知是基于过往的经验和记忆"这一观点提出的. 它与DNN,CNN不同的是: 它不仅考虑前一时刻的输入,而且赋予了网络对前面的内容的一种'记忆'功能.

RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。

3) RNN的主要应用领域有哪些呢?

RNN的应用领域有很多, 可以说只要考虑时间先后顺序的问题都可以使用RNN来解决.这里主要说一下几个常见的应用领域:

① 自然语言处理(NLP): 主要有视频处理, 文本生成, 语言模型, 图像处理

② 机器翻译, 机器写小说

③ 语音识别

④ 图像描述生成

⑤ 文本相似度计算

⑥ 音乐推荐、网易考拉商品推荐、Youtube视频推荐等新的应用领域.

二 RNN(循环神经网络)
    1) RNN模型结构
    前面我们说了RNN具有时间"记忆"的功能, 那么它是怎么实现所谓的"记忆"的呢?

图1 RNN结构图

如图1所示, 我们可以看到RNN层级结构较之于CNN来说比较简单, 它主要有输入层,Hidden Layer, 输出层组成.

并且会发现在Hidden Layer 有一个箭头表示数据的循环更新, 这个就是实现时间记忆功能的方法.

如果到这里你还是没有搞懂RNN到底是什么意思,那么请继续往下看!

图2 Hidden Layer的层级展开图

如图2所示为Hidden Layer的层级展开图. t-1, t, t+1表示时间序列. X表示输入的样本. St表示样本在时间t处的的记忆,St = f(W*St-1 +U*Xt). W表示输入的权重, U表示此刻输入的样本的权重, V表示输出的样本权重.

在t =1时刻, 一般初始化输入S0=0, 随机初始化W,U,V, 进行下面的公式计算:

其中,f和g均为激活函数. 其中f可以是tanh,relu,sigmoid等激活函数,g通常是softmax也可以是其他。

时间就向前推进,此时的状态s1作为时刻1的记忆状态将参与下一个时刻的预测活动,也就是:

以此类推, 可以得到最终的输出值为:

注意: 1. 这里的W,U,V在每个时刻都是相等的(权重共享).

2. 隐藏状态可以理解为:  S=f(现有的输入+过去记忆总结)

2) RNN的反向传播
    前面我们介绍了RNN的前向传播的方式, 那么RNN的权重参数W,U,V都是怎么更新的呢?

每一次的输出值Ot都会产生一个误差值Et, 则总的误差可以表示为:.

则损失函数可以使用交叉熵损失函数也可以使用平方误差损失函数.

由于每一步的输出不仅仅依赖当前步的网络,并且还需要前若干步网络的状态,那么这种BP改版的算法叫做Backpropagation Through Time(BPTT) , 也就是将输出端的误差值反向传递,运用梯度下降法进行更新.(不熟悉BP的可以参考这里)

也就是要求参数的梯度:

首先我们求解W的更新方法, 由前面的W的更新可以看出它是每个时刻的偏差的偏导数之和.

在这里我们以 t = 3时刻为例, 根据链式求导法则可以得到t = 3时刻的偏导数为:

此时, 根据公式我们会发现, S3除了和W有关之外, 还和前一时刻S2有关.

对于S3直接展开得到下面的式子:

对于S2直接展开得到下面的式子:

对于S1直接展开得到下面的式子:

将上述三个式子合并得到:

这样就得到了公式:

这里要说明的是:表示的是S3对W直接求导, 不考虑S2的影响.(也就是例如y = f(x)*g(x)对x求导一样)

其次是对U的更新方法. 由于参数U求解和W求解类似,这里就不在赘述了,最终得到的具体的公式如下:

最后,给出V的更新公式(V只和输出O有关):

三 RNN的一些改进算法
    前面我们介绍了RNN的算法, 它处理时间序列的问题的效果很好, 但是仍然存在着一些问题, 其中较为严重的是容易出现梯度消失或者梯度爆炸的问题(BP算法和长时间依赖造成的). 注意: 这里的梯度消失和BP的不一样,这里主要指由于时间过长而造成记忆值较小的现象.

因此, 就出现了一系列的改进的算法, 这里介绍主要的两种算法: LSTM 和 GRU.

LSTM 和 GRU对于梯度消失或者梯度爆炸的问题处理方法主要是:

对于梯度消失: 由于它们都有特殊的方式存储”记忆”,那么以前梯度比较大的”记忆”不会像简单的RNN一样马上被抹除,因此可以一定程度上克服梯度消失问题。

对于梯度爆炸:用来克服梯度爆炸的问题就是gradient clipping,也就是当你计算的梯度超过阈值c或者小于阈值-c的时候,便把此时的梯度设置成c或-c。

1) LSTM算法(Long Short Term Memory, 长短期记忆网络 ) --- 重要的目前使用最多的时间序列算法

图3 LSTM算法结构图

如图3为LSTM算法的结构图.

和RNN不同的是: RNN中,就是个简单的线性求和的过程. 而LSTM可以通过“门”结构来去除或者增加“细胞状态”的信息,实现了对重要内容的保留和对不重要内容的去除. 通过Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允许任务变量通过”,1表示“运行所有变量通过 ”.

用于遗忘的门叫做"遗忘门", 用于信息增加的叫做"信息增加门",最后是用于输出的"输出门". 这里就不展开介绍了.

此外,LSTM算法的还有一些变种.

如图4所示, 它增加“peephole connections”层 , 让门层也接受细胞状态的输入.

图4 LSTM算法的一个变种

如图5所示为LSTM的另外一种变种算法.它是通过耦合忘记门和更新输入门(第一个和第二个门);也就是不再单独的考虑忘记什么、增加什么信息,而是一起进行考虑。

图5 LSTM算法的一个变种

2) GRU算法

GRU是2014年提出的一种LSTM改进算法. 它将忘记门和输入门合并成为一个单一的更新门, 同时合并了数据单元状态和隐藏状态, 使得模型结构比之于LSTM更为简单.

其各个部分满足关系式如下:

四 基于Tensorflow的基本操作和总结
    使用tensorflow的基本操作如下:

# _*_coding:utf-8_*_

import tensorflow as tf
import numpy as np

'''
TensorFlow中的RNN的API主要包括以下两个路径:
1) tf.nn.rnn_cell(主要定义RNN的几种常见的cell)
2) tf.nn(RNN中的辅助操作)
'''
# 一 RNN中的cell
# 基类(最顶级的父类): tf.nn.rnn_cell.RNNCell()
# 最基础的RNN的实现: tf.nn.rnn_cell.BasicRNNCell()
# 简单的LSTM cell实现: tf.nn.rnn_cell.BasicLSTMCell()
# 最常用的LSTM实现: tf.nn.rnn_cell.LSTMCell()
# RGU cell实现: tf.nn.rnn_cell.GRUCell()
# 多层RNN结构网络的实现: tf.nn.rnn_cell.MultiRNNCell()

# 创建cell
# cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
# print(cell.state_size)
# print(cell.output_size)

# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征
# inputs = tf.placeholder(dtype=tf.float32, shape=[4, 64])

# 给定RNN的初始状态
# s0 = cell.zero_state(4, tf.float32)
# print(s0.get_shape())

# 对于t=1时刻传入输入和state0,获取结果值
# output, s1 = cell.call(inputs, s0)
# print(output.get_shape())
# print(s1.get_shape())

# 定义LSTM cell
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征
inputs = tf.placeholder(tf.float32, shape=[4, 48])
# 给定初始状态
s0 = lstm_cell.zero_state(4, tf.float32)
# 对于t=1时刻传入输入和state0,获取结果值
output, s1 = lstm_cell.call(inputs, s0)
print(output.get_shape())
print(s1.h.get_shape())
print(s1.c.get_shape())
    当然, 你可能会发现使用cell.call()每次只能调用一个得到一个状态, 如有多个状态需要多次重复调用较为麻烦, 那么我们怎么解决的呢? 可以参照后面的基于RNN的手写数字识别和单词预测的实例查找解决方法.

本文主要介绍了一种时间序列的RNN神经网络及其基础上衍生出来的变种算法LSTM和GRU算法, 也对RNN算法的使用场景作了介绍.

当然, 由于篇幅限制, 这里对于双向RNNs和多层的RNNs没有介绍. 另外, 对于LSTM的参数更新算法在这里也没有介绍, 后续补上吧!

最后, 如果你发现了任何问题, 欢迎一起探讨, 共同进步!!
---------------------

RNN概述-深度学习 -神经网络的更多相关文章

  1. 【Todo】【转载】深度学习&神经网络 科普及八卦 学习笔记 & GPU & SIMD

    上一篇文章提到了数据挖掘.机器学习.深度学习的区别:http://www.cnblogs.com/charlesblc/p/6159355.html 深度学习具体的内容可以看这里: 参考了这篇文章:h ...

  2. tensorflow模型持久化保存和加载--深度学习-神经网络

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  3. pytorch深度学习神经网络实现手写字体识别

    利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下: 其具体实现代码如下所示:import torch ...

  4. 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...

  5. [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...

  6. TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN

    前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ...

  7. GitHub 上 57 款最流行的开源深度学习项目

    转载:https://www.oschina.net/news/79500/57-most-popular-deep-learning-project-at-github GitHub 上 57 款最 ...

  8. Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包

    Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包 用 CNTK 搞深度学习 (一) 入门 Computational Network Toolk ...

  9. Github上Stars最多的53个深度学习项目,TensorFlow遥遥领先

    原文:https://github.com/aymericdamien/TopDeepLearning 项目名称 Stars 项目介绍 TensorFlow 29622 使用数据流图计算可扩展机器学习 ...

随机推荐

  1. NLP语言模型

    语言模型: I. 基本思想 区别于其他大多数检索模型从查询到文档(即给定用户查询,如何找出相关的文档), 语言模型由文档到查询,即为每个文档建立不同的语言模型,判断由文档生成用户查 询的可能性有多大, ...

  2. iOS-WebView(WKWebView)进度条

    一直以来,就有想通过技术博客来记录总结下自己工作中碰到的问题的想法,这个想法拖了好久今天才开始着手写自己的第一篇技术博客,由于刚开始写,不免会出现不对的地方,希望各位看到的大牛多多指教.好了,不多说了 ...

  3. Linux下安装 nginx

    安装依赖 yum install gcc yum install pcre-devel yum install zlib zlib-devel yum install openssl openssl- ...

  4. VMware10安装CentOS7

    先去网上下载一个VMware的破解版或者激活版,安装配置这里就不介绍了自行下载安装,基本过程相当于windows下安装个软件而已. CentOS7镜像下载就下阿里云站点的这是链接 http://mir ...

  5. 转:30分钟了解Springboot整合Shiro

    引自:30分钟了解Springboot整合Shiro 前言:06年7月的某日,不才创作了一篇题为<30分钟学会如何使用Shiro>的文章.不在意之间居然斩获了22万的阅读量,许多人因此加了 ...

  6. vue组件间传值详解

    1.父传子----传值要点: <1> 在组件注册的时候必须要使用 return 去返回 data对象;

  7. Android中,子线程使用主线程中的组件出现问题的解决方法

    Android中,主线程中的组件,不能被子线程调用,否则就会出现异常. 这里所使用的方法就是利用Handler类中的Callback(),接受线程中的Message类发来的消息,然后把所要在线程中执行 ...

  8. git 完善使用中

    GIT 版本库控制: 第一步:Git 的账号注册 url :https://github.com/ 这是git的官网如果第一次打开会这样 中间红色圈内是注册 内容, 第一项是用户名 第二项是邮箱 第三 ...

  9. s3c2440中断控制器操作

    一.ARM中断体系结构 arm有7中异常工作模式 用户模式.快中断模式.管理模式.数据访问终止模式.中断模式.系统模式.未定义指令终止模式. 几种模式有什么不同呢, 1.不同的寄存器 2.不同的权限 ...

  10. 『Linux基础 - 3』 Linux文件目录介绍

    Windows 和 Linux 文件系统区别 -- 结构 Windows 下的文件系统 - 在 Windows 下,打开 "计算机",我们看到的是一个个的驱动器盘符: - 每个驱动 ...