代码:

    def forward(self, x):
        '''
        根据式1-式6进行前向计算
        '''
        self.times += 1
        # 遗忘门
        fg = self.calc_gate(x, self.Wfx, self.Wfh,
                            self.bf, self.gate_activator)
        self.f_list.append(fg)
        # 输入门
        ig = self.calc_gate(x, self.Wix, self.Wih,
                            self.bi, self.gate_activator)
        self.i_list.append(ig)
        # 输出门
        og = self.calc_gate(x, self.Wox, self.Woh,
                            self.bo, self.gate_activator)
        self.o_list.append(og)
        # 即时状态
        ct = self.calc_gate(x, self.Wcx, self.Wch,
                            self.bc, self.output_activator)
        self.ct_list.append(ct)
        # 单元状态
        c = fg * self.c_list[self.times - 1] + ig * ct
        self.c_list.append(c)
        # 输出
        h = og * self.output_activator.forward(c)
        self.h_list.append(h)

    def calc_gate(self, x, Wx, Wh, b, activator):
        '''
        计算门
        '''
        h = self.h_list[self.times - 1]  # 上次的LSTM输出
        net = np.dot(Wh, h) + np.dot(Wx, x) + b
        gate = activator.forward(net)
        return gate

    def calc_delta_k(self, k):
        '''
        根据k时刻的delta_h,计算k时刻的delta_f、
        delta_i、delta_o、delta_ct,以及k-1时刻的delta_h
        '''
        # 获得k时刻前向计算的值
        ig = self.i_list[k]
        og = self.o_list[k]
        fg = self.f_list[k]
        ct = self.ct_list[k]
        c = self.c_list[k]
        c_prev = self.c_list[k - 1]
        tanh_c = self.output_activator.forward(c)
        delta_k = self.delta_h_list[k]

        # 根据式9计算delta_o
        delta_o = (delta_k * tanh_c *
                   self.gate_activator.backward(og))
        delta_f = (delta_k * og *
                   (1 - tanh_c * tanh_c) * c_prev *
                   self.gate_activator.backward(fg))
        delta_i = (delta_k * og *
                   (1 - tanh_c * tanh_c) * ct *
                   self.gate_activator.backward(ig))
        delta_ct = (delta_k * og *
                    (1 - tanh_c * tanh_c) * ig *
                    self.output_activator.backward(ct))
        delta_h_prev = (
                np.dot(delta_o.transpose(), self.Woh) +
                np.dot(delta_i.transpose(), self.Wih) +
                np.dot(delta_f.transpose(), self.Wfh) +
                np.dot(delta_ct.transpose(), self.Wch)
        ).transpose()

        # 保存全部delta值
        self.delta_h_list[k - 1] = delta_h_prev
        self.delta_f_list[k] = delta_f
        self.delta_i_list[k] = delta_i
        self.delta_o_list[k] = delta_o
        self.delta_ct_list[k] = delta_ct

    def calc_gradient_t(self, t):
        '''
        计算每个时刻t权重的梯度
        '''
        h_prev = self.h_list[t - 1].transpose()
        Wfh_grad = np.dot(self.delta_f_list[t], h_prev)
        bf_grad = self.delta_f_list[t]
        Wih_grad = np.dot(self.delta_i_list[t], h_prev)
        bi_grad = self.delta_f_list[t]
        Woh_grad = np.dot(self.delta_o_list[t], h_prev)
        bo_grad = self.delta_f_list[t]
        Wch_grad = np.dot(self.delta_ct_list[t], h_prev)
        bc_grad = self.delta_ct_list[t]
        return Wfh_grad, bf_grad, Wih_grad, bi_grad, \
               Woh_grad, bo_grad, Wch_grad, bc_grad

    def calc_gradient(self, x):
        # 初始化遗忘门权重梯度矩阵和偏置项
        self.Wfh_grad, self.Wfx_grad, self.bf_grad = (
            self.init_weight_gradient_mat())
        # 初始化输入门权重梯度矩阵和偏置项
        self.Wih_grad, self.Wix_grad, self.bi_grad = (
            self.init_weight_gradient_mat())
        # 初始化输出门权重梯度矩阵和偏置项
        self.Woh_grad, self.Wox_grad, self.bo_grad = (
            self.init_weight_gradient_mat())
        # 初始化单元状态权重梯度矩阵和偏置项
        self.Wch_grad, self.Wcx_grad, self.bc_grad = (
            self.init_weight_gradient_mat())

        # 计算对上一次输出h的权重梯度
        for t in range(self.times, 0, -1):
            # 计算各个时刻的梯度
            (Wfh_grad, bf_grad,
             Wih_grad, bi_grad,
             Woh_grad, bo_grad,
             Wch_grad, bc_grad) = (
                self.calc_gradient_t(t))
            # 实际梯度是各时刻梯度之和
            self.Wfh_grad += Wfh_grad
            self.bf_grad += bf_grad
            self.Wih_grad += Wih_grad
            self.bi_grad += bi_grad
            self.Woh_grad += Woh_grad
            self.bo_grad += bo_grad
            self.Wch_grad += Wch_grad
            self.bc_grad += bc_grad

        # 计算对本次输入x的权重梯度
        xt = x.transpose()
        self.Wfx_grad = np.dot(self.delta_f_list[-1], xt)
        self.Wix_grad = np.dot(self.delta_i_list[-1], xt)
        self.Wox_grad = np.dot(self.delta_o_list[-1], xt)
        self.Wcx_grad = np.dot(self.delta_ct_list[-1], xt)

参考:

https://zybuluo.com/hanbingtao/note/581764

https://www.cnblogs.com/ratels/p/11416515.html

零基础入门深度学习(6) - 长短时记忆网络(LSTM)的更多相关文章

  1. (转)零基础入门深度学习(6) - 长短时记忆网络(LSTM)

    无论即将到来的是大数据时代还是人工智能时代,亦或是传统行业使用人工智能在云上处理大数据的时代,作为一个有理想有追求的程序员,不懂深度学习(Deep Learning)这个超热的技术,会不会感觉马上就o ...

  2. C#区块链零基础入门,学习路线图 转

    C#区块链零基础入门,学习路线图 一.1分钟短视频<区块链100问>了解区块链基本概念 http://tech.sina.com.cn/zt_d/blockchain_100/ 二.C#区 ...

  3. 长短时记忆网络(LSTM)

    长短时记忆网络 循环神经网络很难训练的原因导致它的实际应用中很处理长距离的依赖.本文将介绍改进后的循环神经网络:长短时记忆网络(Long Short Term Memory Network, LSTM ...

  4. 【零基础学深度学习】动手学深度学习2.0--tensorboard可视化工具简单使用

    1 引言 老师让我将线性回归训练得出的loss值进行可视化,于是我使用了tensorboard将其应用到Pytorch中,用于Pytorch的可视化. 2 环境安装 本教程代码环境依赖: python ...

  5. 长短时记忆网络LSTM和条件随机场crf

    LSTM 原理 CRF 原理 给定一组输入随机变量条件下另一组输出随机变量的条件概率分布模型.假设输出随机变量构成马尔科夫随机场(概率无向图模型)在标注问题应用中,简化成线性链条件随机场,对数线性判别 ...

  6. 机器学习与Tensorflow(5)——循环神经网络、长短时记忆网络

    1.循环神经网络的标准模型 前馈神经网络能够用来建立数据之间的映射关系,但是不能用来分析过去信号的时间依赖关系,而且要求输入样本的长度固定 循环神经网络是一种在前馈神经网络中增加了分亏链接的神经网络, ...

  7. 函数:我的地盘听我的 - 零基础入门学习Python019

    函数:我的地盘听我的 让编程改变世界 Change the world by program 函数与过程 在小甲鱼另一个实践性超强的编程视频教学<零基础入门学习Delphi>中,我们谈到了 ...

  8. 【Python教程】《零基础入门学习Python》(小甲鱼)

    [Python教程]<零基础入门学习Python>(小甲鱼) 讲解通俗易懂,诙谐. 哈哈哈. https://www.bilibili.com/video/av27789609

  9. 《零基础入门学习Python》【第一版】视频课后答案第001讲

    测试题答案: 0. Python 是什么类型的语言? Python是脚本语言 脚本语言(Scripting language)是电脑编程语言,因此也能让开发者藉以编写出让电脑听命行事的程序.以简单的方 ...

随机推荐

  1. 【vue】axios + cookie + 跳转登录方法

    axios 部分: import axios from 'axios' import cookie from './cookie.js' // import constVal from './cons ...

  2. vue项目中解决跨域问题axios和

    项目如果是用脚手架搭建的(vue cli)项目配置文件里有个proxyTable proxyTable是vue-cli搭建webpack脚手架中的一个微型代理服务器,配置如下 配置和安装axios 安 ...

  3. 使用git上传项目解决码云文件次数上传限制(原文)

    起因:个人免费版的码云上传文件时限制: 1个小时内只能上传20个文件 解决方法:在码云创建空的项目仓库,使用git客户端下载码云的项目,把需要上传的文件复制到该项目中去,用git提交! 1.配置git ...

  4. android如何让checkbox实现互斥以及android验证端cession登录注意事项

    1.CheckBox有一个监听器OnChangedListener,每次选择checkbox都会触发这个事件, 里边有一个参数isChecked,就是判断checkbox是否已经选上了的,可以在这判断 ...

  5. 并发编程Semaphore详解

    Semaphore的作用:限制线程并发的数量 位于 java.util.concurrent 下, 构造方法 // 构造函数 代表同一时间,最多允许permits执行acquire() 和releas ...

  6. 【转】解决jenkins自动杀掉衍生进程

    在执行 shell输入框中加入BUILD_ID=dontKillMe ,即可防止jenkins杀死启动的进程 export BUILD_ID=dontKillMe PROJECT_LOCATION=& ...

  7. C\C++改变鼠标样式

    改变鼠标样式可以使用SetClassLong函数 HCURSOR hcur = LoadCursor(NULL, IDC_CROSS); //加载系统自带鼠标样式 HWND hwnd = GetHWn ...

  8. git按需过滤提交文件的一个细节

    问题场景 用git管理代码时,作为git小白的我总会遇到一些无法理解的问题,在请教了一些高手后终于解开了疑惑,参考以下场景: 1.比如我们已在电脑1上完成用vs编辑项目.添加.提交到服务器的完整流程, ...

  9. C:防止头文件重复包含

    当一个项目比较大时,往往都是分文件,这时候有可能不小心把同一个头文件 include 多次,或者头文件嵌套包含. 方法一: #ifndef __SOMEFILE_H__ #define __SOMEF ...

  10. win10无法登陆SSG进行WEB UI管理

    故障描述:尝试登录SSG设备时,无法无法刷出页面,但是设备时可以ping通的(内部接口),可以Telnet上设备,就是无法通过网页登录. 深入测试:win7的系统可以登录,win10的不行,浏览器报协 ...