机器学习(ML)九之GRU、LSTM、深度神经网络、双向循环神经网络
门控循环单元(GRU)
循环神经网络中的梯度计算方法。当时间步数较大或者时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但无法解决梯度衰减的问题。通常由于这个原因,循环神经网络在实际中较难捕捉时间序列中时间步距离较大的依赖关系。
门控循环神经网络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较大的依赖关系。它通过可以学习的门来控制信息的流动。其中,门控循环单元(gated recurrent unit,GRU)是一种常用的门控循环神经网络。
门控循环单元
门控循环单元的设计。它引入了重置门(reset gate)和更新门(update gate)的概念,从而修改了循环神经网络中隐藏状态的计算方式。
重置门和更新门
门控循环单元中的重置门和更新门的输入均为当前时间步输入Xt与上一时间步隐藏状态Ht−1,输出由激活函数为sigmoid函数的全连接层计算得到。
候选隐藏状态
隐藏状态
代码实现
#!/usr/bin/env python
# coding: utf-8 # In[10]: import d2lzh as d2l
from mxnet import nd
from mxnet.gluon import rnn
import zipfile # In[11]: def load_data_jay_lyrics(file):
"""Load the Jay Chou lyric data set (available in the Chinese book)."""
with zipfile.ZipFile(file) as zin:
with zin.open('jaychou_lyrics.txt') as f:
corpus_chars = f.read().decode('utf-8')
corpus_chars = corpus_chars.replace('\n', ' ').replace('\r', ' ')
corpus_chars = corpus_chars[0:10000]
idx_to_char = list(set(corpus_chars))
char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)])
vocab_size = len(char_to_idx)
corpus_indices = [char_to_idx[char] for char in corpus_chars]
return corpus_indices, char_to_idx, idx_to_char, vocab_size # In[12]: file ='/Users/James/Documents/dev/test/data/jaychou_lyrics.txt.zip'
(corpus_indices, char_to_idx, idx_to_char, vocab_size) = load_data_jay_lyrics(file) # In[13]: num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
ctx = d2l.try_gpu() def get_params():
def _one(shape):
return nd.random.normal(scale=0.01, shape=shape, ctx=ctx) def _three():
return (_one((num_inputs, num_hiddens)),
_one((num_hiddens, num_hiddens)),
nd.zeros(num_hiddens, ctx=ctx)) W_xz, W_hz, b_z = _three() # 更新门参数
W_xr, W_hr, b_r = _three() # 重置门参数
W_xh, W_hh, b_h = _three() # 候选隐藏状态参数
# 输出层参数
W_hq = _one((num_hiddens, num_outputs))
b_q = nd.zeros(num_outputs, ctx=ctx)
# 附上梯度
params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
for param in params:
param.attach_grad()
return params # In[14]: def init_gru_state(batch_size, num_hiddens, ctx):
return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx), ) # In[15]: def gru(inputs, state, params):
W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
H, = state
outputs = []
for X in inputs:
Z = nd.sigmoid(nd.dot(X, W_xz) + nd.dot(H, W_hz) + b_z)
R = nd.sigmoid(nd.dot(X, W_xr) + nd.dot(H, W_hr) + b_r)
H_tilda = nd.tanh(nd.dot(X, W_xh) + nd.dot(R * H, W_hh) + b_h)
H = Z * H + (1 - Z) * H_tilda
Y = nd.dot(H, W_hq) + b_q
outputs.append(Y)
return outputs, (H,) # In[16]: num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开'] # In[ ]: d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,
vocab_size, ctx, corpus_indices, idx_to_char,
char_to_idx, False, num_epochs, num_steps, lr,
clipping_theta, batch_size, pred_period, pred_len,
prefixes)
长短期记忆(LSTM)
常用的门控循环神经网络:长短期记忆(long short-term memory,LSTM)。它比门控循环单元的结构稍微复杂一点。
长短期记忆
LSTM 中引入了3个门,即输入门(input gate)、遗忘门(forget gate)和输出门(output gate),以及与隐藏状态形状相同的记忆细胞(某些文献把记忆细胞当成一种特殊的隐藏状态),从而记录额外的信息。
输入门、遗忘门和输出门
候选记忆细胞
记忆细胞
隐藏状态
代码实现
#LSTM 初始化参数
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
ctx = d2l.try_gpu() def get_params():
def _one(shape):
return nd.random.normal(scale=0.01, shape=shape, ctx=ctx) def _three():
return (_one((num_inputs, num_hiddens)),
_one((num_hiddens, num_hiddens)),
nd.zeros(num_hiddens, ctx=ctx)) W_xi, W_hi, b_i = _three() # 输入门参数
W_xf, W_hf, b_f = _three() # 遗忘门参数
W_xo, W_ho, b_o = _three() # 输出门参数
W_xc, W_hc, b_c = _three() # 候选记忆细胞参数
# 输出层参数
W_hq = _one((num_hiddens, num_outputs))
b_q = nd.zeros(num_outputs, ctx=ctx)
# 附上梯度
params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
b_c, W_hq, b_q]
for param in params:
param.attach_grad()
return params # In[19]: def init_lstm_state(batch_size, num_hiddens, ctx):
return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx),
nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx))
深度循环神经网络
双向循环神经网络
机器学习(ML)九之GRU、LSTM、深度神经网络、双向循环神经网络的更多相关文章
- 深度学习之循环神经网络RNN概述,双向LSTM实现字符识别
深度学习之循环神经网络RNN概述,双向LSTM实现字符识别 2. RNN概述 Recurrent Neural Network - 循环神经网络,最早出现在20世纪80年代,主要是用于时序数据的预测和 ...
- 深度学习之循环神经网络(RNN)
循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络,适合用于处理视频.语音.文本等与时序相关的问题.在循环神经网络中,神经元不但可以接收其他神经元 ...
- TensorFlow深度学习实战---循环神经网络
循环神经网络(recurrent neural network,RNN)-------------------------重要结构(长短时记忆网络( long short-term memory,LS ...
- 学习笔记TF057:TensorFlow MNIST,卷积神经网络、循环神经网络、无监督学习
MNIST 卷积神经网络.https://github.com/nlintz/TensorFlow-Tutorials/blob/master/05_convolutional_net.py .Ten ...
- Keras(四)CNN 卷积神经网络 RNN 循环神经网络 原理及实例
CNN 卷积神经网络 卷积 池化 https://www.cnblogs.com/peng8098/p/nlp_16.html 中有介绍 以数据集MNIST构建一个卷积神经网路 from keras. ...
- TensorFlow深度学习笔记 循环神经网络实践
转载请注明作者:梦里风林 Github工程地址:https://github.com/ahangchen/GDLnotes 欢迎star,有问题可以到Issue区讨论 官方教程地址 视频/字幕下载 加 ...
- 开始学习深度学习和循环神经网络Some starting points for deep learning and RNNs
Bengio, LeCun, Jordan, Hinton, Schmidhuber, Ng, de Freitas and OpenAI have done reddit AMA's. These ...
- 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集
#加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...
- 【TensorFlow入门完全指南】神经网络篇·循环神经网络(RNN)
第一步仍然是导入库和数据集. ''' To classify images using a reccurent neural network, we consider every image row ...
随机推荐
- Jenkins2构建pipeline流水线
流水线有两种方式: 1.脚本式流水线 2.声明式流水线 构建流水线的简单示例: 脚本式流水线 node ('master'){ stage("Source"){ //从Git仓库中 ...
- 移动web 1像素边框
实现方法 border-image 图片 实现 这篇文章是腾讯github上的解决方案border-image来实现的 链接走起 <使用border-image实现类似iOS7的1px底边> ...
- 【Java基础总结】字符串
1. java内存区域(堆区.栈区.常量池) 2. String length() //长度 //获取子串位置 indexOf(subStr) lastIndexOf(subStr) //获取子串 c ...
- Java读取数据库中的xml格式内容,解析后修改属性节点内容并写回数据库
直接附代码: 1.测试用的xml内容 <mxGraphModel> <root> <mxCell id="-1" /> <mxCell i ...
- 低功耗蓝牙(BLE)——概述
1. 概述 蓝牙协议是由SIG制定并维护的无线通信协议,蓝牙协议栈是蓝牙协议的具体实现.各厂商都根据蓝牙协议实现了自己的一套函数库--蓝牙协议栈,所以不同厂商的蓝牙协议栈之间虽然存在差别,但是都遵 ...
- 1、纯python编写学生信息管理系统
1.效果图 2.python code: class studentSys(object): ''' _init_(self) 被称为类的构造函数或初始化方法, self 代表类的实例,self 在定 ...
- 序列积第m小元素 二分答案优化
给出两个长度为n的数组A和B, 在A和B中各任取一个, 可以得到n×n个积. 求第m小的元素. n<=100000 这一道题的意思就是 a1 a2 a3 a4.. b1 b2 b3 b4 n^2 ...
- Spring Boot 入门(十二):报表导出,对比poi、jxl和esayExcel的效率
本片博客是紧接着Spring Boot 入门(十一):集成 WebSocket, 实时显示系统日志写的 关于poi.jxl和esayExcel的介绍自行百度. jxl最多支持03版excel,所以单个 ...
- 如何构建可伸缩的Web应用?
为什么要构建可伸缩的Web应用? 想象一下,你的营销活动吸引了很多用户,在某个时候,应用必须同时为成千上万的用户提供服务,这么大的并发量,服务器的负载会很大,如果设计不当,系统将无法处理. 接下来发生 ...
- ssm之spring+springmvc+mybatis整合初探
1.基本目录如下 2.首先是向lib中加入相应的jar包 3.然后在web.xml中加入配置,使spring和springmvc配置文件起作用. <?xml version="1. ...