近些年来,随着深度学习的崛起,RNN模型也变得非常热门。如果把RNN模型按照时间轴展开,它也类似其它的深度神经网络模型结构。因此,我们可以参照已有的方法训练RNN模型。

    现在最流行的一种RNN模型是LSTM(长短期记忆)网络模型。

    尽管我们可以借助Tensorflow、Torch、Theano等深度学习库轻松地训练模型,而不再需要推导反向传播的过程,但是逐步推导LSTM模型的梯度并用反向传播算法来实现,对我们深刻地理解模型是大有裨益的。

    因此,我们首先按照LSTM的公式实现正向传播计算过程,然后推导网络模型的梯度计算过程,最后用numpy来实现模型的求解。

LSTM正向传播

用代码可以表示为:

  1. H = 128 # LSTM 层神经元的数量
  2. D = ... # 输入数据的维度 == 词表的大小
  3. Z = H + D # 因为需要把LSTM的状态与输入数据拼接
  1. model = dict(
  2. Wf=np.random.randn(Z, H) / np.sqrt(Z / 2.),
  3. Wi=np.random.randn(Z, H) / np.sqrt(Z / 2.),
  4. Wc=np.random.randn(Z, H) / np.sqrt(Z / 2.),
  5. Wo=np.random.randn(Z, H) / np.sqrt(Z / 2.),
  6. Wy=np.random.randn(H, D) / np.sqrt(D / 2.),
  7. bf=np.zeros((1, H)),
  8. bi=np.zeros((1, H)),
  9. bc=np.zeros((1, H)),
  10. bo=np.zeros((1, H)),
  11. by=np.zeros((1, D))
  12. )
  1.  
  2. 在上面,我们定义了LSTM单元的结构。上述公式需要注意的一点是,我们把隐藏层上一步的状态h与当前的输入x相连接,因此LSTM单元的输入是 Z = H + D。另外,我们LSTM单元的输出层有H个神经元,因此每个权重矩阵的维度是 ZxH,偏置向量的维度是 1xH
    W

y

  1. b

y

  1. 略有不同,这两项是全连接层的参数,它们的下一级是softmax层。最终的输出结果将是词表中每个词语出现的概率分布,维度为 1xD。因此,W

y

  1. 的维度必须是 HxDb

y

  1. 的维度必须是 1xD
  1. def lstm_forward(X, state):
  2. m = model
  3. Wf, Wi, Wc, Wo, Wy = m['Wf'], m['Wi'], m['Wc'], m['Wo'], m['Wy']
  4. bf, bi, bc, bo, by = m['bf'], m['bi'], m['bc'], m['bo'], m['by']
  5. h_old, c_old = state
  6. # One-hot 编码
  7. X_one_hot = np.zeros(D)
  8. X_one_hot[X] = 1.
  9. X_one_hot = X_one_hot.reshape(1, -1)
  10. # 上一步状态与当前输入值连接
  11. X = np.column_stack((h_old, X_one_hot))
  1.   hf = sigmoid(X @ Wf + bf)
  2. hi = sigmoid(X @ Wi + bi)
  3. ho = sigmoid(X @ Wo + bo)
  4. hc = tanh(X @ Wc + bc)
  5. c = hf * c_old + hi * hc
  6. h = ho * tanh(c)
  7. y = h @ Wy + by
  8. prob = softmax(y)
  9. cache = ... # 存储所有的中间变量结果
  10. return prob, cache

上面的代码表示了单个LSTM单元的前向传播过程,与公式表示的基本一致,多了one-hot编码的步骤。

LSTM反向传播

接下来,我们进入到本篇文章的要点:LSTM反向传播计算。我们假设可以调用函数计算sigmoid和tanh函数的导数。

  1. def lstm_backward(prob, y_train, d_next, cache):
  2. # 取出前向传播步骤中存储的中间状态变量
  3. ... = cache
  4. dh_next, dc_next = d_next
  5. # Softmax loss gradient
  6. dy = prob.copy()
  7. dy[1, y_train] -= 1.

  8. # 隐藏层到输出层的导数
  9. dWy = h.T @ dy
  10. dby = dy
  11. # 注意加上dh_next这一项
  12. dh = dy @ Wy.T + dh_next
  13. # h = ho * tanh(c),计算ho的偏导数
  14. dho = tanh(c) * dh
  15. dho = dsigmoid(ho) * dho
  16. # h = ho * tanh(c), 计算c的偏导数
  17. dc = ho * dh * dtanh(c)
  18. dc = dc + dc_next
  19. # c = hf * c_old + hi * hc,计算hf的偏导数
  20. dhf = c_old * dc
  21. dhf = dsigmoid(hf) * dhf
  22. # c = hf * c_old + hi * hc,计算hi的偏导数
  23. dhi = hc * dc
  24. dhi = dsigmoid(hi) * dhi
  25. # c = hf * c_old + hi * hc,计算hc的偏导数
  26. dhc = hi * dc
  27. dhc = dtanh(hc) * dhc
  28. # 各个门的偏导数
  29. dWf = X.T @ dhf
  30. dbf = dhf
  31. dXf = dhf @ Wf.T
  32. dWi = X.T @ dhi
  33. dbi = dhi
  34. dXi = dhi @ Wi.T
  35. dWo = X.T @ dho
  36. dbo = dho
  37. dXo = dho @ Wo.T
  38. dWc = X.T @ dhc
  39. dbc = dhc
  40. dXc = dhc @ Wc.T
  41. # 由于X参与多个门的计算,因此偏导数需要累加
  42. dX = dXo + dXc + dXi + dXf
  43. # 计算h_old的偏导数
  44. dh_next = dX[:, :H]
  45. # c = hf * c_old + hi * hc,计算dc_next的偏导数
  46. dc_next = hf * dc
  47. grad = dict(Wf=dWf, Wi=dWi, Wc=dWc, Wo=dWo, Wy=dWy, bf=dbf, bi=dbi, bc=dbc, bo=dbo, by=dby)
  48. state = (dh_next, dc_next)
  49. return grad, state

在推导的过程中,不太容易理解地方的有如下几点:

  1. 计算dh时需要加上dh_next,因为在前向过程中,h不仅出现在y = h @ Wy + by,还与下一步计算有关。因此,这里不要忘记加上它。
  2. 计算dc时加上dc_next,理由同上。
  3. 计算dX时,需要累加dXo + dXc + dXi + dXf,理由与上面类似,因为X在多个计算步骤中都有用到。
  4. 因为X = [h_old, x],所以从dx可以得到dh_next。

既然正向和反向传播计算都已经实现,我们就可以合并两者来训练模型。

LSTM训练步骤

训练的过程分为三步:正向计算,计算损失值,反向计算。

  1. python
  2. def train_step(X_train, y_train, state):
  3. probs = []
  4. caches = []
  5. loss = 0.
  6. h, c = state
  7. # 正向计算
  8. for x, y_true in zip(X_train, y_train):
  9. prob, state, cache = lstm_forward(x, state, train=True)
  10. loss += cross_entropy(prob, y_true)
  11. # 保存正向计算的结果
  12. probs.append(prob)
  13. caches.append(cache)
  14. # 损失值采用交叉熵
  15. loss /= X_train.shape[0]
  16. # 反向过程
  17. # 在最后一步, dh_next 和 dc_next 的值等于0。
  18. d_next = (np.zeros_like(h), np.zeros_like(c))
  19. grads = {k: np.zeros_like(v) for k, v in model.items()}
  20. # 按照从后到前的时间顺序
  21. for prob, y_true, cache in reversed(list(zip(probs, y_train, caches))):
  22. grad, d_next = lstm_backward(prob, y_true, d_next, cache)
  23. # 累加各个步骤的梯度值
  24. for k in grads.keys():
  25. grads[k] += grad[k]
  26. return grads, loss, state

在一个完整的训练步骤中,我们首先进行前向计算,保存softmax层的概率分布结果以及每一步的中间结果,因为在反向过程中还会用到。

接着,我们在每一步都能计算交叉熵损失值(因为采用softmax方法)。然后,累加每一步的损失值,并求平均值。

最后,基于前向传播的结果进行反向传播运算,需要注意的是数据遍历的方向与之前相反。

另外,在反向传播的第一步,dh_next和dc_next的值等于0.为什么呢?这是因为在正向计算的最后一步,h和c不会参与下一步的计算,因为不存在下一步!因此,在最后一步h和c的偏导数可以直接推导,不需要考虑dh_next和dc_next。

一旦实现了这个函数,我们稍加修改就可以把它嵌入到任何优化算法中,比如RMSProp、Adam等等。

一切搞定!我们可以尝试训练一个LSTM模型

测试结果

使用Adam优化算法,我从维基百科上复制了一段关于文字。每一个字符表示一个数据。训练目标是预测文章的下一个字符。每隔100轮迭代,我们会检查一下模型的效果。下面是截取到的训练结果:

  1. =========================================================================
  2. Iter-100 loss: 4.2125
  3. =========================================================================
  4. best c ehpnpgteHihcpf,M tt" ao tpo Teoe ep S4 Tt5.8"i neai neyoserpiila o rha aapkhMpl rlp pclf5i
  5. =========================================================================
  6. ...
  7. =========================================================================
  8. Iter-52800 loss: 0.1233
  9. =========================================================================
  10. tary shoguns who ruled in the name of the Uprea wal motrko, the copulation of Japan is a sour the wa
  11. =========================================================================

模型果然学到了一些知识!

小结

在本文中,我们介绍了LSTM的通用公式,并基于此实现了前向计算过程。然后,我们推导了反向计算的过程,尽管加入了一些小技巧,但是整个过程还是非常直截了当。接着,我们将两者结合构建了完整的训练步骤,并用真实的数据训练和测试模型。

lstm-bp过程的手工源码实现的更多相关文章

  1. postgres创建表的过程以及部分源码分析

    背景:修改pg内核,在创建表时,表名不能和当前的用户名同名. 首先我们知道DefineRelation此函数是最终创建表结构的函数,最主要的参数是CreateStmt这个结构,该结构如下 typede ...

  2. 【cs229-Lecture2】Gradient Descent 最小二乘回归问题解析表达式推导过程及实现源码(无需迭代)

    视频地址:http://v.163.com/movie/2008/1/B/O/M6SGF6VB4_M6SGHJ9BO.html 机器学习课程的所有讲义及课后作业:http://pan.baidu.co ...

  3. MyBatis 源码分析 - 配置文件解析过程

    * 本文速览 由于本篇文章篇幅比较大,所以这里拿出一节对本文进行快速概括.本篇文章对 MyBatis 配置文件中常用配置的解析过程进行了较为详细的介绍和分析,包括但不限于settings,typeAl ...

  4. mybatis源码分析(1)——SqlSessionFactory实例的产生过程

    在使用mybatis框架时,第一步就需要产生SqlSessionFactory类的实例(相当于是产生连接池),通过调用SqlSessionFactoryBuilder类的实例的build方法来完成.下 ...

  5. 英蓓特Mars board的android4.0.3源码编译过程

    英蓓特Mars board的android4.0.3源码编译过程 作者:StephenZhu(大桥++) 2013年8月22日 若要转载,请注明出处 一.编译环境搭建及要点: 1. 虚拟机软件virt ...

  6. 2018-11-21 手工翻译Vue.js源码第一步:14个文件重命名

    背景 对现有开源项目的代码进行翻译(文件名/命名/注释) · Issue #107 · program-in-chinese/overview 简单地说, 通过翻译源码, 提高项目代码可读性(对于母语 ...

  7. MyBatis 源码分析 - SQL 的执行过程

    * 本文速览 本篇文章较为详细的介绍了 MyBatis 执行 SQL 的过程.该过程本身比较复杂,牵涉到的技术点比较多.包括但不限于 Mapper 接口代理类的生成.接口方法的解析.SQL 语句的解析 ...

  8. MyBatis 源码分析 - 映射文件解析过程

    1.简介 在上一篇文章中,我详细分析了 MyBatis 配置文件的解析过程.由于上一篇文章的篇幅比较大,加之映射文件解析过程也比较复杂的原因.所以我将映射文件解析过程的分析内容从上一篇文章中抽取出来, ...

  9. ViewPager源码分析——滑动切换页面处理过程

    上周客户反馈Contacts快速滑动界面切换tab有明显卡顿,让优化. 自己验证又没发现卡顿现象,但总得给客户一个技术性的回复,于是看了一下ViewPager源码中处理滑动切换tab的过程. View ...

随机推荐

  1. [转帖]Linux教程(21)-Linux条件循环语句

    Linux教程(21)-Linux条件循环语句 2018-08-24 16:49:03 钱婷婷 阅读数 60更多 分类专栏: Linux教程与操作 Linux教程与使用   版权声明:本文为博主原创文 ...

  2. python实战项目 — selenium登陆豆瓣

    利用selenium 模仿浏览器,登陆豆瓣 重点: 1. 要设置好 chromedriver配置与使用, chromedriver.exe 和 Chrome的浏览器版本要对应, http://chro ...

  3. day45——html常用标签、head内常用标签

    day45 MySQL内容回顾 数据库 DBMS mysql -RDBMS 关系型 数据库分类 关系型:mysql\oracle\sqlserver\access 非关系型:redis,mongodb ...

  4. day27——面向对象的总结、异常处理

    day27 面向对象的总结 异常处理 错误的分类 语法错误 if if 2>1 print(222) dic = {"name"; "alex"} 逻辑错 ...

  5. MySQL数据库的安装(Windows平台)

    1.MySQL数据库安装与配置 1.1 数据库安装和配置 安装需要注意的地方: 典型安装:安装最常用的特性组件,会默认安装至C盘目录下,适合大部分开发者. 自定义安装:可以自定义安装目录,自定义选择安 ...

  6. AS3.0绘图API

    AS3.0绘图API: /** * * *-------------------* * | *** 绘图API *** | * *-------------------* * * 编辑修改收录:fen ...

  7. Luogu4233 射命丸文的笔记 DP、多项式求逆

    传送门 注意到总共有\(\frac{n!}{n}\)条本质不同的哈密顿回路,每一条哈密顿回路恰好会出现在\(2^{\binom{n}{2} - n}\)个图中,所以我们实际上要算的是强连通有向竞赛图的 ...

  8. The three day 给你一个有效的 IPv4 地址 address,返回这个 IP 地址的无效化版本

    """ 给你一个有效的 IPv4 地址 address,返回这个 IP 地址的无效化版本. 所谓无效化 IP 地址,其实就是用 "[.]" 代替了每个 ...

  9. python --- 字符编码学习小结(二)

    距离上一篇的python --- 字符编码学习小结(一)已经过去2年了,2年的时间里,确实也遇到了各种各样的字符编码问题,也能解决,但是每次都是把所有的方法都试一遍,然后终于正常.这种方法显然是不科学 ...

  10. 网页包抓取工具Fiddler工具简单设置

    当下载好fiddler软件后首先通过以下简单设置,或者有时候fiddler抓取不了浏览器资源了.可以通过以下设置. 设置完成后重启软件.打开网络看看有没有抓取到包.