LSTM推导
LSTM推导
forward propagation
def lstm_cell_forward(xt, a_prev, c_prev, parameters):
"""
Implement a single forward step of the LSTM-cell as described in Figure (4)
Arguments:
xt -- your input data at timestep "t", numpy array of shape (n_x, m).
a_prev -- Hidden state at timestep "t-1", numpy array of shape (n_a, m)
c_prev -- Memory state at timestep "t-1", numpy array of shape (n_a, m)
parameters -- python dictionary containing:
Wf -- Weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
bf -- Bias of the forget gate, numpy array of shape (n_a, 1)
Wi -- Weight matrix of the save gate, numpy array of shape (n_a, n_a + n_x)
bi -- Bias of the save gate, numpy array of shape (n_a, 1)
Wc -- Weight matrix of the first "tanh", numpy array of shape (n_a, n_a + n_x)
bc -- Bias of the first "tanh", numpy array of shape (n_a, 1)
Wo -- Weight matrix of the focus gate, numpy array of shape (n_a, n_a + n_x)
bo -- Bias of the focus gate, numpy array of shape (n_a, 1)
Wy -- Weight matrix relating the hidden-state to the output, numpy array of shape (n_y, n_a)
by -- Bias relating the hidden-state to the output, numpy array of shape (n_y, 1)
Returns:
a_next -- next hidden state, of shape (n_a, m)
c_next -- next memory state, of shape (n_a, m)
yt_pred -- prediction at timestep "t", numpy array of shape (n_y, m)
cache -- tuple of values needed for the backward pass, contains (a_next, c_next, a_prev, c_prev, xt, parameters)
Note: ft/it/ot stand for the forget/update/output gates, cct stands for the candidate value (c tilda),
c stands for the memory value
"""
# Retrieve parameters from "parameters"
Wf = parameters["Wf"]
bf = parameters["bf"]
Wi = parameters["Wi"]
bi = parameters["bi"]
Wc = parameters["Wc"]
bc = parameters["bc"]
Wo = parameters["Wo"]
bo = parameters["bo"]
Wy = parameters["Wy"]
by = parameters["by"]
# Retrieve dimensions from shapes of xt and Wy
n_x, m = xt.shape
n_y, n_a = Wy.shape
# Concatenate a_prev and xt (≈3 lines)
concat = np.zeros((n_x+n_a,m))
concat[: n_a, :] = a_prev
concat[n_a :, :] = xt
# Compute values for ft, it, cct, c_next, ot, a_next using the formulas given figure (4) (≈6 lines)
ft = sigmoid(np.dot(Wf,concat)+bf)
it = sigmoid(np.dot(Wi,concat)+bi)
cct = np.tanh(np.dot(Wc,concat)+bc)
c_next = ft*c_prev + it*cct
ot = sigmoid(np.dot(Wo,concat)+bo)
a_next = ot*np.tanh(c_next)
# Compute prediction of the LSTM cell (≈1 line)
yt_pred = softmax(np.dot(Wy, a_next) + by)
# store values needed for backward propagation in cache
cache = (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters)
return a_next, c_next, yt_pred, cache
back propagation
def lstm_cell_backward(da_next, dc_next, cache):
"""
Implement the backward pass for the LSTM-cell (single time-step).
Arguments:
da_next -- Gradients of next hidden state, of shape (n_a, m)
dc_next -- Gradients of next cell state, of shape (n_a, m)
cache -- cache storing information from the forward pass
Returns:
gradients -- python dictionary containing:
dxt -- Gradient of input data at time-step t, of shape (n_x, m)
da_prev -- Gradient w.r.t. the previous hidden state, numpy array of shape (n_a, m)
dc_prev -- Gradient w.r.t. the previous memory state, of shape (n_a, m, T_x)
dWf -- Gradient w.r.t. the weight matrix of the forget gate, numpy array of shape (n_a, n_a + n_x)
dWi -- Gradient w.r.t. the weight matrix of the input gate, numpy array of shape (n_a, n_a + n_x)
dWc -- Gradient w.r.t. the weight matrix of the memory gate, numpy array of shape (n_a, n_a + n_x)
dWo -- Gradient w.r.t. the weight matrix of the save gate, numpy array of shape (n_a, n_a + n_x)
dbf -- Gradient w.r.t. biases of the forget gate, of shape (n_a, 1)
dbi -- Gradient w.r.t. biases of the update gate, of shape (n_a, 1)
dbc -- Gradient w.r.t. biases of the memory gate, of shape (n_a, 1)
dbo -- Gradient w.r.t. biases of the save gate, of shape (n_a, 1)
"""
# Retrieve information from "cache"
(a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, parameters) = cache
# Retrieve dimensions from xt's and a_next's shape (≈2 lines)
n_x, m = xt.shape
n_a, m = a_next.shape
# Compute gates related derivatives, you can find their values can be found by looking carefully at equations (7) to (10) (≈4 lines)
dot = da_next * np.tanh(c_next) * ot * (1 - ot)
dcct = (dc_next * it + ot * (1 - np.square(np.tanh(c_next))) * it * da_next) * (1 - np.square(cct))
dit = (dc_next * cct + ot * (1 - np.square(np.tanh(c_next))) * cct * da_next) * it * (1 - it)
dft = (dc_next * c_prev + ot *(1 - np.square(np.tanh(c_next))) * c_prev * da_next) * ft * (1 - ft)
# Compute parameters related derivatives. Use equations (11)-(14) (≈8 lines)
dWf = np.dot(dft,np.concatenate((a_prev, xt), axis=0).T)
dWi = np.dot(dit,np.concatenate((a_prev, xt), axis=0).T)
dWc = np.dot(dcct,np.concatenate((a_prev, xt), axis=0).T)
dWo = np.dot(dot,np.concatenate((a_prev, xt), axis=0).T)
dbf = np.sum(dft, axis=1 ,keepdims = True)
dbi = np.sum(dit, axis=1, keepdims = True)
dbc = np.sum(dcct, axis=1, keepdims = True)
dbo = np.sum(dot, axis=1, keepdims = True)
# Compute derivatives w.r.t previous hidden state, previous memory state and input. Use equations (15)-(17). (≈3 lines)
da_prev = np.dot(parameters['Wf'][:,:n_a].T,dft)+np.dot(parameters['Wi'][:,:n_a].T,dit)+np.dot(parameters['Wc'][:,:n_a].T,dcct)+np.dot(parameters['Wo'][:,:n_a].T,dot)
dc_prev = dc_next*ft+ot*(1-np.square(np.tanh(c_next)))*ft*da_next
dxt = np.dot(parameters['Wf'][:,n_a:].T,dft)+np.dot(parameters['Wi'][:,n_a:].T,dit)+np.dot(parameters['Wc'][:,n_a:].T,dcct)+np.dot(parameters['Wo'][:,n_a:].T,dot)
# parameters['Wf'][:, :n_a].T 每一行的 第 0 到 n_a-1 列的数据取出来
# parameters['Wf'][:, n_a:].T 每一行的 第 n_a 到最后列的数据取出来
# Save gradients in dictionary
gradients = {"dxt": dxt, "da_prev": da_prev, "dc_prev": dc_prev, "dWf": dWf,"dbf": dbf, "dWi": dWi,"dbi": dbi,
"dWc": dWc,"dbc": dbc, "dWo": dWo,"dbo": dbo}
return gradients
LSTM推导的更多相关文章
- 【Deep Learning】RNN LSTM 推导
http://blog.csdn.net/Dark_Scope/article/details/47056361 http://blog.csdn.net/hongmaodaxia/article/d ...
- 循环神经(LSTM)网络学习总结
摘要: 1.算法概述 2.算法要点与推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 6.适用场合 内容: 1.算法概述 长短期记忆网络(Long Short Term Memory ne ...
- 程序猿 tensorflow 入门开发及人工智能实战
tensorflow 中文文档: http://www.tensorfly.cn http://wiki.jikexueyuan.com/project/tensorflow-zh/ tensorfl ...
- 机器学习 —— 基础整理(八)循环神经网络的BPTT算法步骤整理;梯度消失与梯度爆炸
网上有很多Simple RNN的BPTT(Backpropagation through time,随时间反向传播)算法推导.下面用自己的记号整理一下. 我之前有个习惯是用下标表示样本序号,这里不能再 ...
- LSTM简介以及数学推导(FULL BPTT)
http://blog.csdn.net/a635661820/article/details/45390671 前段时间看了一些关于LSTM方面的论文,一直准备记录一下学习过程的,因为其他事儿,一直 ...
- 《神经网络的梯度推导与代码验证》之LSTM的前向传播和反向梯度推导
前言 在本篇章,我们将专门针对LSTM这种网络结构进行前向传播介绍和反向梯度推导. 关于LSTM的梯度推导,这一块确实挺不好掌握,原因有: 一些经典的deep learning 教程,例如花书缺乏相关 ...
- lstm bptt推导
深蓝 nlp 180429这个有详细的讲解
- GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现
GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现 RNN GRU matlab codes RNN网络考虑到了具有时间数列的样本数据,但是RNN仍存在着一些问题,比如随着 ...
- RNN求解过程推导与实现
RNN求解过程推导与实现 RNN LSTM BPTT matlab code opencv code BPTT,Back Propagation Through Time. 首先来看看怎么处理RNN. ...
- Theano:LSTM源码解析
最难读的Theano代码 这份LSTM代码的作者,感觉和前面Tutorial代码作者不是同一个人.对于Theano.Python的手法使用得非常娴熟. 尤其是在两重并行设计上: ①LSTM各个门之间并 ...
随机推荐
- 越小越好: Q8-Chat,在英特尔至强 CPU 上体验高效的生成式 AI
大语言模型 (LLM) 正在席卷整个机器学习世界.得益于其 transformer 架构,LLM 拥有从大量非结构化数据 (如文本.图像.视频或音频) 中学习的不可思议的能力.它们在 多种任务类型 上 ...
- vscode 注释快捷键 一键注释和取消注释快捷键
// 注释:ctrl+/ /**/ 注释:alt+shift+a
- Hive执行计划之hive依赖及权限查询和常见使用场景
目录 概述 1.explain dependency的查询与使用 2.借助explain dependency解决一些常见问题 2.1.识别看似等价的SQL代码实际上是不等价的: 2.2 通过expl ...
- JavaScript 显示数据
JavaScript 显示数据 JavaScript 可以通过不同的方式来输出数据: 使用 window.alert() 弹出警告框. 使用 document.write() 方法将内容写到 HTML ...
- 1. Spring相关概念
1. 初始 Spring 1.1 Spring 家族 官网:https://spring.io,从官网我们可以大概了解到: Spring 能做什么:用以开发 web.微服务以及分布式系统等, ...
- Redis基础(二)——列表操作、redis管道、Django中使用redis
Redis列表操作 ''' lpush(name,values) rpush(name, values) 表示从右向左操作 lpushx(name,value) rpushx(name, value) ...
- 从0开发WebGPU渲染引擎:开篇
大家好,本系列会从0开始,开发一个基于WebGPU的路径追踪渲染器,使用深度学习降噪.DLSS等AI技术实现实时渲染:并且基于自研的低代码开发平台,让用户可以通过可视化拖拽的方式快速搭建自定义的Web ...
- Visual Studio Code安装C#开发工具包并编写ASP.NET Core Web应用
前言 前段时间微软发布了适用于VS Code的C#开发工具包(注意目前该包还属于预发布状态但是可以正常使用),因为之前看过网上的一些使用VS Code搭建.NET Core环境的教程看着还挺复杂的就一 ...
- 移动端APP组件化架构实践
前言 对于中大型移动端APP开发来讲,组件化是一种常用的项目架构方式.个人最近几年在工作项目中也一直使用组件化的方式来开发,在这过程中也积累了一些经验和思考.主要是来自在日常开发中使用组件化开发遇到的 ...
- 【Docker】部署Tomcat
搜索镜像 $ docker search 镜像名称:镜像TAG # 如: 没有加TAG,表示默认搜索的是最新版本的tomcat镜像 $ docker search tomcat # 如:搜索 tomc ...