lstm和gru详解
一、LSTM(长短期记忆网络)
LSTM是一种特殊的RNN类型,一般的RNN结构如下图所示,是一种将以往学习的结果应用到当前学习的模型,但是这种一般的RNN存在着许多的弊端。举个例子,如果我们要预测“the clouds are in the sky”的最后一个单词,因为只在这一个句子的语境中进行预测,那么将很容易地预测出是这个单词是sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,RNN 可以学会使用先前的信息。
标准的RNN结构中只有一个神经元,一个tanh层进行重复的学习,这样会存在一些弊端。例如,在比较长的环境中,例如在“I grew up in France… I speak fluent French”中去预测最后的French,那么模型会推荐一种语言的名字,但是预测具体是哪一种语言时就需要用到很远以前的Franch,这就说明在长环境中相关的信息和预测的词之间的间隔可以是非常长的。在理论上,RNN 绝对可以处理这样的长环境问题。人们可以仔细挑选参数来解决这类问题中的最初级形式,但在实践中,RNN 并不能够成功学习到这些知识。然而,LSTM模型就可以解决这一问题。
如图所示,标准LSTM模型是一种特殊的RNN类型,在每一个重复的模块中有四个特殊的结构,以一种特殊的方式进行交互。在图中,每一条黑线传输着一整个向量,粉色的圈代表一种pointwise 操作(将定义域上的每一点的函数值分别进行运算),诸如向量的和,而黄色的矩阵就是学习到的神经网络层。
LSTM模型的核心思想是“细胞状态”。“细胞状态”类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。
LSTM 有通过精心设计的称作为“门”的结构来去除或者增加信息到细胞状态的能力。门是一种让信息选择式通过的方法。他们包含一个 sigmoid 神经网络层和一个 pointwise 乘法操作。
Sigmoid 层输出 0 到 1 之间的数值,描述每个部分有多少量可以通过。0 代表“不许任何量通过”,1 就指“允许任意量通过”。LSTM 拥有三个门,来保护和控制细胞状态。
在LSTM模型中,第一步是决定我们从“细胞”中丢弃什么信息,这个操作由一个忘记门层来完成。该层读取当前输入x和前神经元信息h,由ft来决定丢弃的信息。输出结果1表示“完全保留”,0 表示“完全舍弃”。
第二步是确定细胞状态所存放的新信息,这一步由两层组成。sigmoid层作为“输入门层”,决定我们将要更新的值i;tanh层来创建一个新的候选值向量~Ct加入到状态中。在语言模型的例子中,我们希望增加新的主语到细胞状态中,来替代旧的需要忘记的主语。
第三步就是更新旧细胞的状态,将Ct-1更新为Ct。我们把旧状态与 ft相乘,丢弃掉我们确定需要丢弃的信息。接着加上 it * ~Ct。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。在语言模型的例子中,这就是我们实际根据前面确定的目标,丢弃旧代词的信息并添加新的信息的地方。
最后一步就是确定输出了,这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。首先,我们运行一个 sigmoid 层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。在语言模型的例子中,因为语境中有一个代词,可能需要输出与之相关的信息。例如,输出判断是一个动词,那么我们需要根据代词是单数还是负数,进行动词的词形变化。
二、GRU( Gated Recurrent Unit,LSTM变体)
GRU作为LSTM的一种变体,将忘记门和输入门合成了一个单一的更新门。同样还混合了细胞状态和隐藏状态,加诸其他一些改动。最终的模型比标准的 LSTM 模型要简单,也是非常流行的变体。
三、对比
概括的来说,LSTM和GRU都能通过各种Gate将重要特征保留,保证其在long-term 传播的时候也不会被丢失。
可以看出,标准LSTM和GRU的差别并不大,但是都比tanh要明显好很多,所以在选择标准LSTM或者GRU的时候还要看具体的任务是什么。
使用LSTM的原因之一是解决RNN Deep Network的Gradient错误累积太多,以至于Gradient归零或者成为无穷大,所以无法继续进行优化的问题。GRU的构造更简单:比LSTM少一个gate,这样就少几个矩阵乘法。在训练数据很大的情况下GRU能节省很多时间。
五、应用
样例:预测当前词。每个时刻的输入都是一个embedding向量,它的长度是输入层神经元的个数,与时间步的个数(即句子的长度)没有关系。
每个时刻的输出是一个概率分布向量,其中最大值的下标决定了输出哪个词。
RNN&LSTM实际应用:
1. Language ModelThe Unreasonable Effectiveness of Recurrent Neural Networks http://karpathy.github.io/2015/05/21/rnn-effectiveness/
2. Image Captioning[CVPR15]]Long-term Recurrent Convolutional Networks for Visual Recognition and DescriptionDeep Visual-Semantic Alignments for Generating Image Descriptions http://cs.stanford.edu/people/karpathy/deepimagesent/
3. Speech Recognition
4. Machine Translation[NIPS15]Sequence to Sequence Learning
with Neural Networks. http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf
连接出处:https://blog.csdn.net/lreaderl/article/details/78022724
lstm和gru详解的更多相关文章
- 【pytorch】关于Embedding和GRU、LSTM的使用详解
1. Embedding的使用 pytorch中实现了Embedding,下面是关于Embedding的使用. torch.nn包下的Embedding,作为训练的一层,随模型训练得到适合的词向量. ...
- RNN 与 LSTM 的原理详解
原文地址:https://blog.csdn.net/happyrocking/article/details/83657993 RNN(Recurrent Neural Network)是一类用于处 ...
- (数据科学学习手札39)RNN与LSTM基础内容详解
一.简介 循环神经网络(recurrent neural network,RNN),是一类专门用于处理序列数据(时间序列.文本语句.语音等)的神经网络,尤其是可以处理可变长度的序列:在与传统的时间序列 ...
- tensorflow LSTM+CTC使用详解
最近用tensorflow写了个OCR的程序,在实现的过程中,发现自己还是跳了不少坑,在这里做一个记录,便于以后回忆.主要的内容有lstm+ctc具体的输入输出,以及TF中的CTC和百度开源的warp ...
- torch.nn.LSTM()函数维度详解
123456789101112lstm=nn.LSTM(input_size, hidden_size, num_la ...
- pytorch nn.LSTM()参数详解
输入数据格式:input(seq_len, batch, input_size)h0(num_layers * num_directions, batch, hidden_size)c0(num_la ...
- Kaggle网站流量预测任务第一名解决方案:从模型到代码详解时序预测
Kaggle网站流量预测任务第一名解决方案:从模型到代码详解时序预测 2017年12月13日 17:39:11 机器之心V 阅读数:5931 近日,Artur Suilin 等人发布了 Kaggl ...
- Attention is all you need 论文详解(转)
一.背景 自从Attention机制在提出之后,加入Attention的Seq2Seq模型在各个任务上都有了提升,所以现在的seq2seq模型指的都是结合rnn和attention的模型.传统的基于R ...
- seq2seq模型详解及对比(CNN,RNN,Transformer)
一,概述 在自然语言生成的任务中,大部分是基于seq2seq模型实现的(除此之外,还有语言模型,GAN等也能做文本生成),例如生成式对话,机器翻译,文本摘要等等,seq2seq模型是由encoder, ...
随机推荐
- 二叉搜索树中第K小的元素
给定一个二叉搜索树,编写一个函数 kthSmallest 来查找其中第 k 个最小的元素. 说明:你可以假设 k 总是有效的,1 ≤ k ≤ 二叉搜索树元素个数. 示例 1: 输入: root = [ ...
- ClickHouse
ClickHouse 是俄罗斯的Yandex于2016年开源的列式存储数据库(DBMS),主要用于在线分析处理查询(OLAP),能够使用SQL查询实时生成分析数据报告 1 安装前的准备1.1 Cent ...
- CF1244C The Football Season
题目链接 problem 给定\(n,p,w,d\),求解任意一对\((x,y)\)满足\[xw+yd=p\\ x + y \le n\] \(1\le n\le 10^{12},0\le p\le ...
- import和from...import
目录 一.import 模块名 二.from 模块名 import 具体的功能 三.import和from...import...的异同 一般使用import和from...import...导入模块 ...
- navicat连接mysql报错1251解决方案,从头搭建node + mysql 8.0 (本人亲测有效)
准备学node 好久了 一直没有动手去写,今天突发奇想,然后就安装了一个mysql (找了一个博客跟着步骤去安装的),然后打算用node 写个增删改查. 1.下载mysql安装包 地址: http ...
- js 元素自动点击/执行问题
a标签对于一下两种方式是无效的: <a href="http://qq.com">QQ</a> $('.obj').click(); $('.obj').t ...
- Android常用adb命令总结(二)
adb shell 命令 简单点讲,adb 命令是 adb 这个程序自带的一些命令,而 adb shell 则是调用的 Android 系统中的命令,这些 Android 特有的命令都放在了 Andr ...
- python执行shell实时输出
1.使用readline可以实现 import subprocess def run_shell(shell): cmd = subprocess.Popen(shell, stdin=subproc ...
- make 命令与 Makefile
make 是一个工具程序,通过读取 Makefile 文件,实现自动化软件构建.虽然现代软件开发中,集成开发环境已经取代了 make,但在 Unix 环境中,make 仍然被广泛用来协助软件开发.ma ...
- Python爬取知乎单个问题下的回答
前言 本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理. 作者: 努力学习的渣渣哦 PS:如有需要Python学习资料的小伙伴可以加 ...