RNN概述-深度学习 -神经网络
一 RNN概述
前面我们叙述了BP算法, CNN算法, 那么为什么还会有RNN呢?? 什么是RNN, 它到底有什么不同之处? RNN的主要应用领域有哪些呢?这些都是要讨论的问题.
1) BP算法,CNN之后, 为什么还有RNN?
细想BP算法,CNN(卷积神经网络)我们会发现, 他们的输出都是只考虑前一个输入的影响而不考虑其它时刻输入的影响, 比如简单的猫,狗,手写数字等单个物体的识别具有较好的效果. 但是, 对于一些与时间先后有关的, 比如视频的下一时刻的预测,文档前后文内容的预测等, 这些算法的表现就不尽如人意了.因此, RNN就应运而生了.
2) 什么是RNN?
RNN是一种特殊的神经网络结构, 它是根据"人的认知是基于过往的经验和记忆"这一观点提出的. 它与DNN,CNN不同的是: 它不仅考虑前一时刻的输入,而且赋予了网络对前面的内容的一种'记忆'功能.
RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。
3) RNN的主要应用领域有哪些呢?
RNN的应用领域有很多, 可以说只要考虑时间先后顺序的问题都可以使用RNN来解决.这里主要说一下几个常见的应用领域:
① 自然语言处理(NLP): 主要有视频处理, 文本生成, 语言模型, 图像处理
② 机器翻译, 机器写小说
③ 语音识别
④ 图像描述生成
⑤ 文本相似度计算
⑥ 音乐推荐、网易考拉商品推荐、Youtube视频推荐等新的应用领域.
二 RNN(循环神经网络)
1) RNN模型结构
前面我们说了RNN具有时间"记忆"的功能, 那么它是怎么实现所谓的"记忆"的呢?
图1 RNN结构图
如图1所示, 我们可以看到RNN层级结构较之于CNN来说比较简单, 它主要有输入层,Hidden Layer, 输出层组成.
并且会发现在Hidden Layer 有一个箭头表示数据的循环更新, 这个就是实现时间记忆功能的方法.
如果到这里你还是没有搞懂RNN到底是什么意思,那么请继续往下看!
图2 Hidden Layer的层级展开图
如图2所示为Hidden Layer的层级展开图. t-1, t, t+1表示时间序列. X表示输入的样本. St表示样本在时间t处的的记忆,St = f(W*St-1 +U*Xt). W表示输入的权重, U表示此刻输入的样本的权重, V表示输出的样本权重.
在t =1时刻, 一般初始化输入S0=0, 随机初始化W,U,V, 进行下面的公式计算:
其中,f和g均为激活函数. 其中f可以是tanh,relu,sigmoid等激活函数,g通常是softmax也可以是其他。
时间就向前推进,此时的状态s1作为时刻1的记忆状态将参与下一个时刻的预测活动,也就是:
以此类推, 可以得到最终的输出值为:
注意: 1. 这里的W,U,V在每个时刻都是相等的(权重共享).
2. 隐藏状态可以理解为: S=f(现有的输入+过去记忆总结)
2) RNN的反向传播
前面我们介绍了RNN的前向传播的方式, 那么RNN的权重参数W,U,V都是怎么更新的呢?
每一次的输出值Ot都会产生一个误差值Et, 则总的误差可以表示为:.
则损失函数可以使用交叉熵损失函数也可以使用平方误差损失函数.
由于每一步的输出不仅仅依赖当前步的网络,并且还需要前若干步网络的状态,那么这种BP改版的算法叫做Backpropagation Through Time(BPTT) , 也就是将输出端的误差值反向传递,运用梯度下降法进行更新.(不熟悉BP的可以参考这里)
也就是要求参数的梯度:
首先我们求解W的更新方法, 由前面的W的更新可以看出它是每个时刻的偏差的偏导数之和.
在这里我们以 t = 3时刻为例, 根据链式求导法则可以得到t = 3时刻的偏导数为:
此时, 根据公式我们会发现, S3除了和W有关之外, 还和前一时刻S2有关.
对于S3直接展开得到下面的式子:
对于S2直接展开得到下面的式子:
对于S1直接展开得到下面的式子:
将上述三个式子合并得到:
这样就得到了公式:
这里要说明的是:表示的是S3对W直接求导, 不考虑S2的影响.(也就是例如y = f(x)*g(x)对x求导一样)
其次是对U的更新方法. 由于参数U求解和W求解类似,这里就不在赘述了,最终得到的具体的公式如下:
最后,给出V的更新公式(V只和输出O有关):
三 RNN的一些改进算法
前面我们介绍了RNN的算法, 它处理时间序列的问题的效果很好, 但是仍然存在着一些问题, 其中较为严重的是容易出现梯度消失或者梯度爆炸的问题(BP算法和长时间依赖造成的). 注意: 这里的梯度消失和BP的不一样,这里主要指由于时间过长而造成记忆值较小的现象.
因此, 就出现了一系列的改进的算法, 这里介绍主要的两种算法: LSTM 和 GRU.
LSTM 和 GRU对于梯度消失或者梯度爆炸的问题处理方法主要是:
对于梯度消失: 由于它们都有特殊的方式存储”记忆”,那么以前梯度比较大的”记忆”不会像简单的RNN一样马上被抹除,因此可以一定程度上克服梯度消失问题。
对于梯度爆炸:用来克服梯度爆炸的问题就是gradient clipping,也就是当你计算的梯度超过阈值c或者小于阈值-c的时候,便把此时的梯度设置成c或-c。
1) LSTM算法(Long Short Term Memory, 长短期记忆网络 ) --- 重要的目前使用最多的时间序列算法
图3 LSTM算法结构图
如图3为LSTM算法的结构图.
和RNN不同的是: RNN中,就是个简单的线性求和的过程. 而LSTM可以通过“门”结构来去除或者增加“细胞状态”的信息,实现了对重要内容的保留和对不重要内容的去除. 通过Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允许任务变量通过”,1表示“运行所有变量通过 ”.
用于遗忘的门叫做"遗忘门", 用于信息增加的叫做"信息增加门",最后是用于输出的"输出门". 这里就不展开介绍了.
此外,LSTM算法的还有一些变种.
如图4所示, 它增加“peephole connections”层 , 让门层也接受细胞状态的输入.
图4 LSTM算法的一个变种
如图5所示为LSTM的另外一种变种算法.它是通过耦合忘记门和更新输入门(第一个和第二个门);也就是不再单独的考虑忘记什么、增加什么信息,而是一起进行考虑。
图5 LSTM算法的一个变种
2) GRU算法
GRU是2014年提出的一种LSTM改进算法. 它将忘记门和输入门合并成为一个单一的更新门, 同时合并了数据单元状态和隐藏状态, 使得模型结构比之于LSTM更为简单.
其各个部分满足关系式如下:
四 基于Tensorflow的基本操作和总结
使用tensorflow的基本操作如下:
# _*_coding:utf-8_*_
import tensorflow as tf
import numpy as np
'''
TensorFlow中的RNN的API主要包括以下两个路径:
1) tf.nn.rnn_cell(主要定义RNN的几种常见的cell)
2) tf.nn(RNN中的辅助操作)
'''
# 一 RNN中的cell
# 基类(最顶级的父类): tf.nn.rnn_cell.RNNCell()
# 最基础的RNN的实现: tf.nn.rnn_cell.BasicRNNCell()
# 简单的LSTM cell实现: tf.nn.rnn_cell.BasicLSTMCell()
# 最常用的LSTM实现: tf.nn.rnn_cell.LSTMCell()
# RGU cell实现: tf.nn.rnn_cell.GRUCell()
# 多层RNN结构网络的实现: tf.nn.rnn_cell.MultiRNNCell()
# 创建cell
# cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
# print(cell.state_size)
# print(cell.output_size)
# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征
# inputs = tf.placeholder(dtype=tf.float32, shape=[4, 64])
# 给定RNN的初始状态
# s0 = cell.zero_state(4, tf.float32)
# print(s0.get_shape())
# 对于t=1时刻传入输入和state0,获取结果值
# output, s1 = cell.call(inputs, s0)
# print(output.get_shape())
# print(s1.get_shape())
# 定义LSTM cell
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征
inputs = tf.placeholder(tf.float32, shape=[4, 48])
# 给定初始状态
s0 = lstm_cell.zero_state(4, tf.float32)
# 对于t=1时刻传入输入和state0,获取结果值
output, s1 = lstm_cell.call(inputs, s0)
print(output.get_shape())
print(s1.h.get_shape())
print(s1.c.get_shape())
当然, 你可能会发现使用cell.call()每次只能调用一个得到一个状态, 如有多个状态需要多次重复调用较为麻烦, 那么我们怎么解决的呢? 可以参照后面的基于RNN的手写数字识别和单词预测的实例查找解决方法.
本文主要介绍了一种时间序列的RNN神经网络及其基础上衍生出来的变种算法LSTM和GRU算法, 也对RNN算法的使用场景作了介绍.
当然, 由于篇幅限制, 这里对于双向RNNs和多层的RNNs没有介绍. 另外, 对于LSTM的参数更新算法在这里也没有介绍, 后续补上吧!
最后, 如果你发现了任何问题, 欢迎一起探讨, 共同进步!!
---------------------
RNN概述-深度学习 -神经网络的更多相关文章
- 【Todo】【转载】深度学习&神经网络 科普及八卦 学习笔记 & GPU & SIMD
上一篇文章提到了数据挖掘.机器学习.深度学习的区别:http://www.cnblogs.com/charlesblc/p/6159355.html 深度学习具体的内容可以看这里: 参考了这篇文章:h ...
- tensorflow模型持久化保存和加载--深度学习-神经网络
模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...
- pytorch深度学习神经网络实现手写字体识别
利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下: 其具体实现代码如下所示:import torch ...
- 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践
https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...
- [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践
转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...
- TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN
前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ...
- GitHub 上 57 款最流行的开源深度学习项目
转载:https://www.oschina.net/news/79500/57-most-popular-deep-learning-project-at-github GitHub 上 57 款最 ...
- Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包
Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包 用 CNTK 搞深度学习 (一) 入门 Computational Network Toolk ...
- Github上Stars最多的53个深度学习项目,TensorFlow遥遥领先
原文:https://github.com/aymericdamien/TopDeepLearning 项目名称 Stars 项目介绍 TensorFlow 29622 使用数据流图计算可扩展机器学习 ...
随机推荐
- Process.waitFor()导致主线程堵塞问题
今日开发的时候使用jdk自带的运行时变量 RunTime.getRunTime() 去执行bash命令.因为该bash操作耗时比较长,所以使用了Process.waitFor()去等待子线程运行结束. ...
- 竞赛题解 - NOIP2018 赛道修建
\(\mathcal {NOIP2018}\) 赛道修建 - 竞赛题解 额--考试的时候大概猜到正解,但是时间不够了,不敢写,就写了骗分QwQ 现在把坑填好了~ 题目 (Copy from 洛谷) 题 ...
- Ubuntu 16.04LTS 更新清华源
1 备份原来的更新源 cp /etc/apt/sources.list /etc/apt/sources.list.backup 如果提示权限不够就输入下面两行,先进入到超级用户,再备份 sudo - ...
- jQuery获取Select option 选择的Text和 Value
获取一组radio被选中项的值:var item = $('input[name=items][checked]').val();获取select被选中项的文本var item = $("s ...
- HTML5中的拖拽与拖放(drag&&drop)
1.drag 当拖动某个元素时,将会依次触发下列事件: 1)dragstart:按下鼠标键并开始移动鼠标时,会触发该事件 2)drag:dragstart触发后,随即便触发drag事件,而且在元素被拖 ...
- 月薪30-50K的大数据工程师们,他们背后是如何学习的
这两天小编去了解了下大数据开发相关职位的薪资,主要有hadoop工程师,数据挖掘工程师.大数据算法工程师等,从平均薪资来看,目前大数据相关岗位的月薪均在2万以上,随着项目经验的增长工资会越来越高. ...
- ruby做接口测试
一. 工具选择 IDE:rubymine:http接口请求:Unirest,ruby单元测试框架:rspec 二.工程创建 新建工程,在工程目录下,执行:rspec --init:初始化rspec工程 ...
- 闰年相关的问题v3.0——计算有多少闰年
# include<stdio.h>int main(){ int a,b,i; int sum = 0; printf("Input your birth year:" ...
- linux进程篇 (三) 进程间的通信1 管道通信
通信方式分4大类: 管道通信:无名管道 有名管道 信号通信:发送 接收 和 处理 IPC通信:共享内存 消息队列 信号灯 socke 网络通信 用户空间 进程A <----无法通信----> ...
- 北京Uber优步司机奖励政策(12月11日)
滴快车单单2.5倍,注册地址:http://www.udache.com/ 如何注册Uber司机(全国版最新最详细注册流程)/月入2万/不用抢单:http://www.cnblogs.com/mfry ...