lstm公式推导
http://blog.csdn.net/u010754290/article/details/47167979
导言
在Alex Graves的这篇论文《Supervised Sequence Labelling with Recurrent Neural Networks》中对LSTM进行了综述性的介绍,并对LSTM的Forward Pass和Backward Pass进行了公式推导。
这篇文章将用更简洁的图示和公式一步步对Forward和Backward进行推导,相信读者看完之后能对LSTM有更深入的理解。
如果读者对LSTM的由来和原理存在困惑,推荐DarkScope的这篇博客:《RNN以及LSTM的介绍和公式梳理》
一、LSTM的基础结构
LSTM的结构中每个时刻的隐层包含了多个memory blocks(一般我们采用一个block),每个block包含了多个memory cell,每个memory cell包含一个Cell和三个gate,一个基础的结构示例如下图:
一个memory cell只能产出一个标量值,一个block能产出一个向量。
二、LSTM的前向传播(Forward Pass)
1. 引入
首先我们在上述LSTM的基础结构之上构造时序结构,这样让读者更清晰地看到Recurrent的结构:
这里我们有几个约定:
- 每个时刻的隐层包含一个block
- 每个block包含一个memory cell
下面前向传播我们则从Input开始,逐个求解Input Gate、Forget Gate、Cells Gate、Ouput Gate和最终的Output
这里需要申明的一点,推导过程严格按照上述图示LSTM的结构;论文中对相较于该文章的推导过程会有增加一些项,在每一个公式不一致的地方我都会有相应说明。
2. Input Gate(ι) 的计算
Input Gate接受两个输入:
- 当前时刻的Input作为输入:xt
- 上一时刻同一block内所有Cell作为输入:st−1c
该案例中每层仅有单个Block、单个cemory cell,可以忽略∑Cc=1,以下Forget Gate和Output Gate做相同处理。
最终Input Gate的输出为:
这里Input Gate还可以接受上一个时刻中不同block的输出bt−1h作为输入,论文中atι会增加一项∑Hh=1ωhιbt−1h。
3. Forget Gate(ϕ) 的计算
Forget Gate接受两个输入:
- 当前时刻的Input作为输入:xt
- 上一时刻同一block内所有Cell作为输入:st−1c
最终Forget Gate的输出为:
这里Input Gate还可以接受上一个时刻中不同block的输出bt−1h作为输入,论文中atϕ会增加一项∑Hh=1ωhϕbt−1h。
4. Cell(c) 的计算
Cell的计算稍有些复杂,接受两个输入:
- Input Gate和Input输入的乘积
- Forget Gate和上一时刻对应Cell输出的乘积
最终Cell的输出为:
这里Input Gate还可以接受上一个时刻中不同block的输出bt−1h作为输入,论文中atc会增加一项∑Hh=1ωhcbt−1h。
5. Output Gate(ω) 的计算
Output Gate接受两个输入:
- 当前时刻的Input作为输入:xt
- 当前时刻同一block内所有Cell作为输入:stc
这里Output Gate接受“当前时刻Cell的输出”而不是“上一时刻Cell的输出”,是由于此时Cell的结果已经产出,我们控制Output Gate的输出直接采用Cell当前的结果就行了,无须使用上一时刻。
最终Output Gate的输出为:
这里Cell还可以接受上一个时刻中其他gate链接过来的边,论文中atϕ会增加一项∑Hh=1ωhϕbt−1h,这里H是泛指t-1时刻的Cell或三个Gate。
6. Cell Output(c) 的计算
Cell Output的计算即将Output Gate和Cell做乘积即可。
最终Cell Output为:
7. 小结
至此,整个Block从Input到Output整个Forward Pass已经结束,其中涉及三个Gate和中间Cell的计算,需要注意的是三个Gate使用的激活函数是f,而Input的激活函数是g、Cell输出的激活函数是h。
这里读者需要注意,在整个计算过程中,当前时刻的三个Gate均可以从上一时刻的任意Gate中接受输入,在公式中存在体现,但是在图示中并未画出相应的边。我们可以认为只有上一时刻的Cell才和当前时刻的Cell或三个Gate相连。
三、LSTM的反向传播(Backward Pass)
1. 引入
此处在论文中使用“Backward Pass”一词,但其实即Back Propagation过程,利用链式求导求解整个LSTM中每个权重的梯度。
2. 损失函数的选择
为了通用起见,在此我们仅展示多分类问题的损失函数的选择,对于网络的最终输出我们利用softmax方程计算结果属于某一类的概率(此时结果属于k个类别的概率和为1)。
注意,yk对ak的偏导为∂yk′∂ak=ykδkk′−ykyk′(δkk′当k==k′时为1,其他为0)
其中,对于网络输出a1,a2,...对应我们可以得到p(C1|x),p(C2|x),...,即给定输入x输出类别为C1,C2,...的概率。
这样损失函数(Loss Function)就很好定义了:对于k∈1,2,...,K,网络输出的类别为k概率为yk,而真实值zk:
3. 权重的更新
对于神经网络中的每一个权重,我们都需要找到对应的梯度,从而通过不断地用训练样本进行随机梯度下降找到全局最优解,那么首先我们需要知道哪些权重需要更新。
一般层次分明的神经网络有input层、hidden层和output层,层与层之间的权重比较直观;但在LSTM中通过公式才能找到对应的权重,和图示中的边并不是一一对应,下面我将LSTM的单个Block中需要更新的权重在图示上标示了出来:
为了方便起见,这里需要申明的是:我们仅考虑上一时刻的Cell仅和当前时刻的Cell和三个Gate相连。
2. Cell Output的梯度
首先我们计算每一个输出类别的梯度:
也即每一个输出类别的梯度仅和其预测值和真实值相关,这样对于Cell Output的梯度则可以通过链式求导法则推导出来:
由于Output还可以连接下一个时刻的一个Cell、三个Gate,那么下一个时刻的一个Cell、三个Gate的梯度则可以传递回当前时刻Output,所以在论文中存在额外项∑Gg=1ωcgδt+1g,为简便起见,公式和图示中未包含。
3. Output Gate的梯度
根据链式求导法则,Output Gate的梯度可以由以下公式推导出来:
另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Output Gate的梯度写成了f′(atw)∑Cc=1ϵtch(stc),但推导过程一致。推导过程见下图,说明梯度汇总到单个Gate中:
4. Cell的梯度
细心的读者在这里会发现,Cell的计算结构和普遍的神经网络不太一样,让我们首先来回顾一下Cell部分的Forward计算过程:
输入数据贡献给atc,而Cell同时能够接受Input Gate和Forget Gate的输入。
这样梯度就直接从Cell向下传递:
在这里,我们定义States,由于Cell的梯度可以由以下几个计算单元传递回来:
- 当前时刻的Cell Output
- 下一个时刻的Cell
- 下一个时刻的Input Gate
- 下一个时刻的Output Gate
那么States可以这样求解,上面1~4个能够回传梯度的计算单元和下面公式中一一对应:
那么:
细心的读者会发现,论文中∂(x,z)∂btc并没有求和,这里作者持保留态度,应该存在求和项。
同时由于Cell可以连接到下一个时刻的Forget Gate、Output Gate和Input Gate,那么下一时刻的这三个Gate则可以将梯度传播回来,所以在论文中我们会发现ϵts拥有这三项:bt+1ϕϵt+1s、ωclδt+1ι和ωcϕδt+1ϕ。
5. Forget Gate的梯度
Forget Gate的梯度计算就比较简单明了:
另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Forget Gate的梯度写成了f′(atϕ)∑Cc=1st−1cϵts,但推导过程一致,说明梯度汇总到单个Gate中。
6. Input Gate的梯度
Input Gate的梯度计算如下:
另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Input Gate的梯度写成了f′(atι)∑Cc=1g(atc)ϵts,但推导过程一致,说明梯度汇总到单个Gate中。
7. 小结
至此,所有的梯度求解已经结束,同样我们将这个Backward Pass的所有公式列出来:
剩下的事情即利用梯度去更新每个权重:
其中mΔωn−1为上一次权重的更新值,且m∈[0,1];而∂∂ωn即上面我们求到的每一个梯度。
例如每次更新ωiϕ的Δ量即:
其中δtϕ即Forget Gate的梯度。
三、总结
以上就是LSTM中的前向和反向传播的公式推导,在这里作者仅以最简单的单个Cell的场景进行示例。
在实际工程实践中,常常会涉及到同一时刻多个Cell且互相之间的Gate存在连接,同时上一个时刻或下一个时刻的Cell和三个Gate之间同样存在复杂的连接关系。
但如果读者能够明晰上述的推导过程,那么无论多复杂都能够迎刃而解了。
lstm公式推导的更多相关文章
- RNN LSTM 介绍
[RNN以及LSTM的介绍和公式梳理]http://blog.csdn.net/Dark_Scope/article/details/47056361 [知乎 对比 rnn lstm 简单代码] ...
- RNN/LSTM/GRU/seq2seq公式推导
概括:RNN 适用于处理序列数据用于预测,但却受到短时记忆的制约.LSTM 和 GRU 采用门结构来克服短时记忆的影响.门结构可以调节流经序列链的信息流.LSTM 和 GRU 被广泛地应用到语音识别. ...
- 学习笔记CB012: LSTM 简单实现、完整实现、torch、小说训练word2vec lstm机器人
真正掌握一种算法,最实际的方法,完全手写出来. LSTM(Long Short Tem Memory)特殊递归神经网络,神经元保存历史记忆,解决自然语言处理统计方法只能考虑最近n个词语而忽略更久前词语 ...
- 详解LSTM
https://blog.csdn.net/class_brick/article/details/79311148 今天的内容有: LSTM 思路 LSTM 的前向计算 LSTM 的反向传播 关于调 ...
- RNN以及LSTM的介绍和公式梳理
前言 好久没用正儿八经地写博客了,csdn居然也有了markdown的编辑器了,最近花了不少时间看RNN以及LSTM的论文,在组内『夜校』分享过了,再在这里总结一下发出来吧,按照我讲解的思路,理解RN ...
- LSTM简介以及数学推导(FULL BPTT)
http://blog.csdn.net/a635661820/article/details/45390671 前段时间看了一些关于LSTM方面的论文,一直准备记录一下学习过程的,因为其他事儿,一直 ...
- RNN(Recurrent Neural Networks)公式推导和实现
RNN(Recurrent Neural Networks)公式推导和实现 http://x-algo.cn/index.php/2016/04/25/rnn-recurrent-neural-net ...
- Caffe2:使用Caffe构建LSTM网络
前言: 一般所称的LSTM网络全叫全了应该是使用LSTM单元的RNN网络. 原文:(Caffe)LSTM层分析 入门篇:理解LSTM网络 LSTM的官方简介: http://deeplearning. ...
- pytorch nn.LSTM()参数详解
输入数据格式:input(seq_len, batch, input_size)h0(num_layers * num_directions, batch, hidden_size)c0(num_la ...
随机推荐
- 对java多线程的一些浅浅的理解
作为一名JAVA初学者,前几天刚刚接触多线程这个东西,有了些微微的理解想写下来(不对的地方请多多包涵并指教哈). 多线程怎么写代码就不说了,一搜一大堆.说说多线程我认为最难搞的地方,就是来回释放锁以及 ...
- 大咖分享 | 一文解锁首届云创大会干货——上篇(文末附演讲ppt文件免费下载)
日,第一届网易云创大会在杭州国际博览中心举办,本次大会由杭州滨江区政府和网易主办,杭州市两创示范工作领导小组办公室协办,网易云承办,以"商业匠心.技术创新"为主题,致力于打通技术创 ...
- Jquery+Ajax+asp.net+sqlserver-编写的通用邮件管理(源码)
开始 邮件管理通常用在各个内部系统中,为了方便快捷的使用现有的代码开发一个邮件管理系统而诞生的. 准备条件 这是我的设计表结构,大家一看就懂了 --邮件接收表CREATE TABLE [dbo]. ...
- 用上GIT你一定会爱上他
前言 Git是一个开源的分布式版本控制系统,用以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控 ...
- python redis中blpop和lpop的区别
python redis 中blpop返回的是元组对象,因此返回的时候注意 lpop返回的是对象
- [转]查看Linux版本信息
一.查看Linux内核版本命令(两种方法): 1.cat /proc/version [root@S-CentOS home]# cat /proc/version Linux version 2.6 ...
- 【bzoj4319】cerc2008 Suffix reconstruction 贪心
题目描述 话说练习后缀数组时,小C 刷遍 poj 后缀数组题, 各类字符串题闻之丧胆.就在准备对敌方武将发出连环杀时,对方一记无中生有,又一招顺手牵羊,小C 程序中的原字符数组就被牵走了.幸运的是,小 ...
- [转] Makefile 基础 (8) —— Makefile 隐含规则
该篇文章为转载,是对原作者系列文章的总汇加上标注. 支持原创,请移步陈浩大神博客:(最原始版本) http://blog.csdn.net/haoel/article/details/2886 我转自 ...
- d3 svg简单学习
矩形 <rect x="/> 圆形 <circle cx="/> 椭圆 <ellipse cx="/> 线 <line x1=& ...
- Linq技巧2——限制返回数据中的继承类型
假如有像下面这样的一个模型, 怎样在查询时仅仅需要的Cars呢? 这样的几个继承关系的实体中,查询时Where 条件可以加入OfType<SubType>(),你可以这样来写: var o ...