关于 RNN 循环神经网络的反向传播求导

本文是对 RNN 循环神经网络中的每一个神经元进行反向传播求导的数学推导过程,下面还使用 PyTorch 对导数公式进行编程求证。

RNN 神经网络架构

一个普通的 RNN 神经网络如下图所示:

其中 \(x^{\langle t \rangle}\) 表示某一个输入数据在 \(t\) 时刻的输入;\(a^{\langle t \rangle}\) 表示神经网络在 \(t\) 时刻时的hidden state,也就是要传送到 \(t+1\) 时刻的值;\(y^{\langle t \rangle}\) 则表示在第 \(t\) 时刻输入数据传入以后产生的预测值,在进行预测或 sampling 时 \(y^{\langle t \rangle}\) 通常作为下一时刻即 \(t+1\) 时刻的输入,也就是说 \(x^{\langle t \rangle}=\hat{y}^{\langle t \rangle}\) ;下面对数据的维度进行说明。

  • 输入: \(x\in\mathbb{R}^{n_x\times m\times T_x}\) 其中 \(n_x\) 表示每一个时刻输入向量的长度;\(m\) 表示数据批量数(batch);\(T_x\) 表示共有多少个输入的时刻(time step)。
  • hidden state:\(a\in\mathbb{R}^{n_a\times m\times T_x}\) 其中 \(n_a\) 表示每一个 hidden state 的长度。
  • 预测:\(y\in\mathbb{R}^{n_y\times m\times T_y}\) 其中 \(n_y\) 表示预测输出的长度;\(T_y\) 表示共有多少个输出的时刻(time step)。

RNN 神经元

下图所示的是一个特定的 RNN 神经元:

上图说明了在第 \(t\) 时刻的神经元中,数据的输入 \(x^{\langle t \rangle}\) 和上一层的 hidden state \(a^{\langle t \rangle}\) 是如何经过计算得到下一层的 hidden state 和预测输出 \(\hat{y}^{\langle t \rangle}\) 。

下面是对五个参数的维度说明:

  • \(W_{aa}\in\mathbb{R}^{n_a\times n_a}\)
  • \(W_{ax}\in\mathbb{R}^{n_a\times n_x}\)
  • \(b_a\in\mathbb{R}^{n_a\times 1}\)
  • \(W_{ya}\in\mathbb{R}^{n_y\times n_a}\)
  • \(b_y\in\mathbb{R}^{n_y\times 1}\)

计算 \(t\) 时刻的 hidden state \(a^{\langle t \rangle}\) :

\[\begin{split}
z1^{\langle t \rangle} &= W_{aa} a^{\langle t-1 \rangle} + W_{ax} x^{\langle t \rangle} + b_a\\
a^{\langle t \rangle} &= \tanh(z1^{\langle t \rangle})
\end{split}
\]

预测 \(t\) 时刻的输出 \(\hat{y}^{\langle t \rangle}\) :

\[\begin{split}
z2^{\langle t \rangle} &= W_{ya} a^{\langle t \rangle} + b_y\\
\hat{y}^{\langle t \rangle} &= softmax(z2^{\langle t \rangle}) = \frac{e^{z2^{\langle t \rangle}}}{\sum_{i=1}^{n_y}e^{z2_i^{\langle t \rangle}}}
\end{split}
\]

RNN 循环神经网络反向传播

在当今流行的深度学习编程框架中,我们只需要编写一个神经网络的结构和负责神经网络的前向传播,至于反向传播的求导和参数更新,完全由框架搞定;即便如此,我们在学习阶段也要自己动手证明一下反向传播的有效性。

RNN 神经元的反向传播

下图是 RNN 神经网络中的一个基本的神经元,图中标注了反向传播所需传来的参数和输出等。

就如一个全连接的神经网络一样,损失函数 \(J\) 的导数通过微积分的链式法则(chain rule)反向传播到每一个时间轴上。

为了方便,我们将损失函数关于神经元中参数的偏导符号简记为 \(\mathrm{d}\mathit{parameters}\) ;例如将 \(\frac{\partial J}{\partial W_{ax}}\) 记为 \(\mathrm{d}W_{ax}\) 。

上图的反向传播的实现并没有包括全连接层和 Softmax 层。

反向传播求导

计算损失函数关于各个参数的偏导数之前,我们先引入一个计算图(computation graph),其演示了一个 RNN 神经元的前向传播和如何利用计算图进行链式法则的反向求导。

因为当进行反向传播求导时,我们需要将整个时间轴的输入全部输入之后,才可以从最后一个时刻开始往前传进行反向传播,所以我们假设 \(t\) 时刻就为最后一个时刻 \(T_x\) 。

如果我们想要先计算 \(\frac{\partial\ell}{\partial W_{ax}}\) 所以我们可以从计算图中看到,反向传播的路径:

我们需要按部就班的分别对从 \(W_{ax}\) 计算到 \(\ell\) 一路相关的变量进行求偏导,利用链式法则,将红色路线上一路的偏导数相乘到一起,就可以求出偏导数 \(\frac{\partial\ell}{\partial W_{ax}}\) ;所以我们得到:

\[\begin{split}
\frac{\partial\ell}{\partial W_{ax}}
&=
\frac{\partial\ell}{\partial\ell^{\langle t\rangle}}
{\color{Red}{
\frac{\partial\ell^{\langle t\rangle}}{\partial\hat{y}^{\langle t\rangle}}
\frac{\partial\hat{y}^{\langle t\rangle}}{\partial z2^{\langle t\rangle}}
}}
\frac{\partial z2^{\langle t\rangle}}{\partial a^{\langle t\rangle}}
\frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}}
\frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}}
\end{split}
\]

在上面的公式中,我们仅需要分别求出每一个偏导即可,其中红色的部分就是关于 \(\mathrm{Softmax}\) 的求导,关于 \(\mathrm{Softmax}\) 求导的推导过程,可以看本人的另一篇博客: 关于 Softmax 回归的反向传播求导数过程

关于 \(\mathrm{tanh}\) 的求导公式如下:

\[\frac{\partial \tanh(x)} {\partial x} = 1 - \tanh^2(x)
\]

所以上面的式子就得到:

\[\begin{split}
\frac{\partial\ell}{\partial W_{ax}}
&=
\frac{\partial\ell}{\partial\ell^{\langle t\rangle}}
{\color{Red}{
\frac{\partial\ell^{\langle t\rangle}}{\partial\hat{y}^{\langle t\rangle}}
\frac{\partial\hat{y}^{\langle t\rangle}}{\partial z2^{\langle t\rangle}}
}}
\frac{\partial z2^{\langle t\rangle}}{\partial a^{\langle t\rangle}}
\frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}}
\frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}}\\
&=
{\color{Red}{
(\hat{y}^{\langle t\rangle}-y^{\langle t\rangle})
}}
W_{ya}
(1-\tanh^2(z1^{\langle t\rangle}))
x^{\langle t\rangle}
\end{split}
\]

我们就可以得到在最后时刻 \(t\) 参数 \(W_{ax}\) 的偏导数。

关于上面式子中的偏导数的计算,除了标量对矩阵的求导,在后面还包括了两个一个矩阵或向量对另一个矩阵或向量中的求导,实际上这是非常麻烦的一件事。

比如在计算 \(\frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}}\) 偏导数的时候,我们发现 \(z1^{\langle t\rangle}\) 是一个 \(\mathbb{R}^{n_a\times m}\) 的矩阵,而 \(W_{ax}\) 则是一个 \(\mathbb{R}^{n_a\times n_x}\) 的矩阵,这一项就是一个矩阵对另一个矩阵求偏导,如果直接对其求导我们将会得到一个四维的矩阵 \(\mathbb{R}^{n_a\times n_x\times n_a\times m}\) (雅可比矩阵 Jacobian matrix);只不过这个高维矩阵中偏导数的值有很多 \(0\) 。

在神经网络中,如果直接将这个高维矩阵直接生搬硬套进梯度下降里更新参数是不可行,因为我们需要得到的梯度是关于自变量同型的向量或矩阵而且我们还要处理更高维度的矩阵的乘法;所以我们需要将结果进行一定的处理得到我们仅仅需要的信息。

一般在深度学习框架中都会有自动求梯度的功能包,这些包(比如 PyTorch )中就只允许一个标量对向量或矩阵求导,其他情况是不允许的,除非在反向传播的函数里传入一个同型的权重向量或矩阵才可以得到导数。

我们先简单求出一个偏导数 \(\frac{\partial\ell}{\partial W_{ax}}\) 我们下面使用 PyTorch 中的自动求梯度的包进行验证我们的公式是否正确。

  1. import torch
  1. # 这是神经网络中的一些架构的参数
  2. n_x = 6
  3. n_y = 6
  4. m = 1
  5. T_x = 5
  6. T_y = 5
  7. n_a = 3
  1. # 定义所有参数矩阵
  2. # requires_grad 为 True 表明在涉及这个变量的运算时建立计算图
  3. # 为了之后反向传播求导
  4. W_ax = torch.randn((n_a, n_x), requires_grad=True)
  5. W_aa = torch.randn((n_a, n_a), requires_grad=True)
  6. ba = torch.randn((n_a, 1), requires_grad=True)
  7. W_ya = torch.randn((n_y, n_a), requires_grad=True)
  8. by = torch.randn((n_y, 1), requires_grad=True)
  1. # t 时刻的输入和上一时刻的 hidden state
  2. x_t = torch.randn((n_x, m), requires_grad=True)
  3. a_prev = torch.randn((n_a, m), requires_grad=True)
  4. y_t = torch.randn((n_y, m), requires_grad=True)
  1. # 开始模拟一个神经元 t 时刻的前向传播
  2. # 从输入一直到计算出 loss
  3. z1_t = torch.matmul(W_ax, x_t) + torch.matmul(W_aa, a_prev) + ba
  4. z1_t.retain_grad()
  5. a_t = torch.tanh(z1_t)
  6. a_t.retain_grad()
  7. z2_t = torch.matmul(W_ya, a_t) + by
  8. z2_t.retain_grad()
  9. y_hat = torch.exp(z2_t) / torch.sum(torch.exp(z2_t), dim=0)
  10. y_hat.retain_grad()
  11. loss_t = -torch.sum(y_t * torch.log(y_hat), dim=0)
  12. loss_t.retain_grad()
  1. # 对最后的 loss 标量开始进行反向传播求导
  2. loss_t.backward()
  1. # 我们就可以得到 W_ax 的导数
  2. # 存储在后缀 _autograd 变量中,表明是由框架自动求导得到的
  3. W_ax_autograd = W_ax.grad
  1. # 查看框架计算得到的导数
  2. W_ax_autograd
  1. tensor([[ 0.5252, 1.1938, -0.2352, 1.1571, -1.0168, 0.3195],
  2. [-1.0536, -2.3949, 0.4718, -2.3213, 2.0398, -0.6410],
  3. [-0.0316, -0.0717, 0.0141, -0.0695, 0.0611, -0.0192]])
  1. # 我们对自己推演出的公式进行手动计算导数
  2. # 存储在后缀 _manugrad 变量中,表明是手动由公式计算得到的
  3. W_ax_manugrad = torch.matmul(torch.matmul((y_hat - y_t).T, W_ya).T * (1 - torch.square(torch.tanh(z1_t))), x_t.T)
  4. #torch.matmul(torch.matmul(W_ya.T, y_hat - y_t) * (1 - torch.square(torch.tanh(z1_t))), x_t.T)
  1. # 输出手动计算的导数
  2. W_ax_manugrad
  1. tensor([[ 0.5195, 1.1809, -0.2327, 1.1447, -1.0058, 0.3161],
  2. [-1.0195, -2.3172, 0.4565, -2.2461, 1.9737, -0.6202],
  3. [-0.0309, -0.0703, 0.0138, -0.0681, 0.0599, -0.0188]],
  4. grad_fn=<MmBackward>)
  1. # 查看两种求导结果的之差的 L2 范数
  2. torch.norm(W_ax_manugrad - W_ax_autograd)
  1. tensor(0.1356, grad_fn=<CopyBackwards>)

通过上面的编程输出可以看到,我们手动计算的导数和框架自己求出的导数虽然有一定的误差,但是一一对照可以大体看到我们手动求出来的导数大体是对的,并没有说错的非常离谱。

但上面只是当 \(t=T_x\) 即 \(t\) 时刻是最后一个输入单元的时候,也就是说所求的关于 \(_W{ax}\) 的导数只是全部导数的一部分,因为参数共享,所以每一时刻的神经元都有对 \(W_{ax}\) 的导数,所以需要将所有时刻的神经元关于 \(W_{ax}\) 的导数全部加起来。

若 \(t\) 不是最后一时刻,可能是神经网络里的中间的某一时刻的神经元;也就是说,在进行反向传播的时候,想要求 \(t\) 时刻的导数,就得等到 \(t+1\) 时刻的导数值传进来,然后根据链式法则才可以计算当前时刻参数的导数。

下面是一个简易的计算图,只绘制出了 \(W_ax\) 到 \(\ell\) 的计算中,共涉及到哪些变量(在整个神经网络中的 \(W_{ax}\) 的权重参数是共享的):

下面使用一个视频展示整个神经网络中从 \(W_{ax}\) 到一个数据批量的损失值 \(\ell\) 的大体流向:

计算完 \(\ell\) 之后就可以计算 \(\frac{\partial\ell}{\partial W_{ax}}\) 的导数值,但是 RNN 神经网络的反向传播区别于全连接神经网络的。

然后,我们演示一下如何进行反向传播的,注意看每一个时刻的 \(a^{\langle t\rangle}\) 的计算都是等 \(a^{\langle t+1\rangle}\) 的导数值传进来才进行计算的;同样地,\(W_{ax}\) 导数的计算也不是一步到位的,也是需要等到所有时刻的 \(a\) 的值全部传到才计算完。

所以对于神经网络中间某一个单元 \(t\) 我们有:

\[\begin{split}
\frac{\partial\ell}{\partial W_{ax}}
&=
{\color{Red}{
\left(
\frac{\partial\ell}{\partial a^{\langle t\rangle}}
+\frac{\partial\ell}{\partial z1^{\langle t+1\rangle}}
\frac{\partial z1^{\langle t+1\rangle}}{\partial a^{\langle t\rangle}}
\right)
}}
\frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}}
\frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}}
\end{split}
\]

关于红色的部分的意思是需要等到 \(t+1\) 时刻的导数值传进来,然后才可以进行对 \(t+1\) 时刻关于当前时刻 \(t\) 的参数求导,最后得到参数梯度的一个分量。其实若仔细展开每一个偏导项,就像是一个递归一样,每次求某一时刻的导数总是要从最后一时刻往前传到当前时刻才可以进行。

多元复合函数的求导法则

如果函数 \(u=\varphi(t)\) 及 \(v=\psi(t)\) 都在点 \(t\) 可导,函数 \(z=f(u,v)\) 在对应点 \((u,v)\) 具有连续偏导数,那么复合函数 \(z=f[\varphi(t),\psi(t)]\) 在点 \(t\) 可导,且有

\[\frac{\mathrm{d}z}{\mathrm{d}t}=\frac{\partial z}{\partial u}\frac{\mathrm{d}u}{\mathrm{d}t}+\frac{\partial z}{\partial v}\frac{\mathrm{d}v}{\mathrm{d}t}
\]

下面使用一张计算图说明 \(a^{\langle t\rangle}\) 到 \(\ell\) 的计算关系。

也就是说第 \(t\) 时刻 \(\ell\) 关于 \(a^{\langle t\rangle}\) 的导数是由两部分相加组成,也就是说是由两条路径反向传播,这两条路径分别是 \(\ell\to\ell^{\langle t\rangle}\to\hat{y}^{\langle t\rangle}\to z2^{\langle t\rangle}\to a^{\langle t\rangle}\) 和 \(\ell\to\ell^{\langle t+1\rangle}\to\hat{y}^{\langle t+1\rangle}\to z2^{\langle t+1\rangle}\to a^{\langle t+1\rangle}\to z1^{\langle t+1\rangle}\to a^{\langle t\rangle}\) ,我们将这两条路径导数之和使用 \(\mathrm{d}a_{\mathrm{next}}\) 表示。

所以我们可以得到在中间某一时刻的神经单元关于 \(W_{ax}\) 的导数为:

\[\frac{\partial\ell}{\partial W_{ax}}=\left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t \rangle}\right)\right) x^{\langle t \rangle T}
\]

通过同样的方法,我们就可以得到其它参数的导数:

\[\begin{align}
\frac{\partial\ell}{\partial W_{aa}}
&=
\left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right) a^{\langle t-1 \rangle T}\\
\frac{\partial\ell}{\partial b_a}
& = \sum_{batch}\left( da_{next} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right)\\
\end{align}
\]

除了传递参数的导数,在第 \(t\) 时刻还需要传送 \(\ell\) 关于 \(z1^{\langle t\rangle}\) 的导数到 \(t-1\) 时刻,将需要传送到上一时刻的导数记作为 \(\mathrm{d}a_{\mathrm{prev}}\) 我们得到:

\[\begin{split}
\mathrm{d}a_{\mathrm{prev}} &= \mathrm{d}a_\mathrm{next}\frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}}\frac{\partial z1^{\langle t\rangle}}{\partial a^{\langle t-1\rangle}}\\
&= { W_{aa}}^T\left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right)
\end{split}
\]

可以看到,一个循环神经网络的反向传播实际上是非常复杂的,因为每一时刻的神经元都与参数有计算关系,所以反向传播时的路径非常杂乱,其中还涉及到了高维的矩阵,所以在计算时需要对高维矩阵进行一定的矩阵代数转换才方便导数和更新参数的计算。

关于 RNN 循环神经网络的反向传播求导的更多相关文章

  1. 关于 Softmax 回归的反向传播求导数过程

    对于 \(Softmax\) 回归的正向传播非常简单,就是对于一个输入 \(X\) 对每一个输入标量 \(x_i\) 进行加权求和得到 \(Z\) 然后对其做概率归一化. Softmax 示意图 下面 ...

  2. 《神经网络的梯度推导与代码验证》之vanilla RNN前向和反向传播的代码验证

    在<神经网络的梯度推导与代码验证>之vanilla RNN的前向传播和反向梯度推导中,我们学习了vanilla RNN的前向传播和反向梯度求导,但知识仍停留在纸面.本篇章将基于深度学习框架 ...

  3. CNN(卷积神经网络)、RNN(循环神经网络)、DNN(深度神经网络)的内部网络结构有什么区别?

    https://www.zhihu.com/question/34681168 CNN(卷积神经网络).RNN(循环神经网络).DNN(深度神经网络)的内部网络结构有什么区别?修改 CNN(卷积神经网 ...

  4. Recurrent Neural Networks(RNN) 循环神经网络初探

    1. 针对机器学习/深度神经网络“记忆能力”的讨论 0x1:数据规律的本质是能代表此类数据的通用模式 - 数据挖掘的本质是在进行模式提取 数据的本质是存储信息的介质,而模式(pattern)是信息的一 ...

  5. AI之旅(7):神经网络之反向传播

    前置知识   求导 知识地图   神经网络算法是通过前向传播求代价,反向传播求梯度.在上一篇中介绍了神经网络的组织结构,逻辑关系和代价函数.本篇将介绍如何求代价函数的偏导数(梯度). 梯度检测   在 ...

  6. 神经网络之反向传播算法(BP)公式推导(超详细)

    反向传播算法详细推导 反向传播(英语:Backpropagation,缩写为BP)是"误差反向传播"的简称,是一种与最优化方法(如梯度下降法)结合使用的,用来训练人工神经网络的常见 ...

  7. NLP教程(3) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/36 本文地址:http://www.showmeai.tech/article-det ...

  8. 深度学习与CV教程(4) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

  9. TensorFlow框架(6)之RNN循环神经网络详解

    1. RNN循环神经网络 1.1 结构 循环神经网络(recurrent neural network,RNN)源自于1982年由Saratha Sathasivam 提出的霍普菲尔德网络.RNN的主 ...

随机推荐

  1. VS Code 搭建 Rust 开发环境

    VS Code 搭建 Rust 开发环境 上一篇文章安装和配置好了 Rust 环境后,我们是使用的是简单的文本工具编写 Hello World 入门代码,但是为了提高我们的学习效率,下面安利大家 VS ...

  2. js- 判断属性是否 属于该对象 hasOwnProperty()

    var obj ={ name:'suan', sex :'male', age:150, height:185, characeter:true, _proto_:{ lastName:'susan ...

  3. PHP文件上传、错误处理

    说明 这篇是针对之前php知识的补充内容 目录 说明 1. PHP目录处理函数 2. PHP文件权限设置 3. PHP文件路径函数 4. PHP实现文件留言本 5.PHP文件上传 1. php文件上传 ...

  4. 安全声明标记语言SAML2.0初探

    目录 简介 SAML的构成 SAML的优势 SAML是怎么工作的 SP redirect request; IdP POST response SP POST Request; IdP POST Re ...

  5. MISC-吹着贝斯扫二维码

    题目 [安洵杯 2019]吹着贝斯扫二维码 解压附件,有36个文件和一个压缩包,压缩包带密码和备注 分析 文件类型 随便打开一个不明文件,是jpg图片啊(FF D8 FF) 改一个试试,有一个小块二维 ...

  6. Mybatis学习-日志与分页

    日志 为什么需要日志 如果一个数据库操作出现了异常,需要排错,那么日志就是最好的助手 Mybatis通过使用内置的日志工厂提供日志功能,有一下几种实现方式: SLF4J Apache Commons ...

  7. 微服务开发的最大痛点-分布式事务SEATA入门简介

    前言 在微服务开发中,存在诸多的开发痛点,例如分布式事务.全链路跟踪.限流降级和服务平滑上下线等.而在这其中,分布式事务是最让开发者头痛的.那分布式事务是什么呢? 分布式事务就是指事务的参与者.支持事 ...

  8. js上 十、循环语句-1:

    十.循环语句-1: 非常之重要. 作用:重复执行一段代码 ü while ü do...while ü for 它们的相同之处,都能够实现循环. 不同的地方,格式不一样,使用的场景略有不同. #10- ...

  9. BullseyeCoverage——C/C++代码覆盖率分析工具

    安装 官网下载地址: https://www.bullseye.com/cgi-bin/download 安装过程中需要License, 可以申请评估版本的key或者直接从官网渠道购买. Ubuntu ...

  10. 说说在 Python 中如何快速复制序列

    1 基本用法 把序列乘以一个整数,就会产生一个新序列.这个新序列是原始序列复制了整数份,然后再拼接起来的结果. l=[1,2,3] l2=l * 3 logging.info('l2 -> %s ...