pytorch, LSTM介绍
本文中的RNN泛指LSTM,GRU等等
CNN中和RNN中batchSize的默认位置是不同的。
- CNN中:batchsize的位置是
position 0
. - RNN中:batchsize的位置是
position 1
.
在RNN中输入数据格式:
对于最简单的RNN,我们可以使用两种方式来调用,torch.nn.RNNCell()
,它只接受序列中的单步输入,必须显式的传入隐藏状态。torch.nn.RNN()
可以接受一个序列的输入,默认会传入一个全0的隐藏状态,也可以自己申明隐藏状态传入。
- 输入大小是三维tensor
[seq_len,batch_size,input_dim]
input_dim
是输入的维度,比如是128
batch_size
是一次往RNN
输入句子的数目,比如是5
。seq_len
是一个句子的最大长度,比如15
所以千万注意,RNN
输入的是序列,一次把批次的所有句子都输入了,得到的ouptut
和hidden
都是这个批次的所有的输出和隐藏状态,维度也是三维。
**可以理解为现在一共有batch_size
个独立的RNN
组件,RNN
的输入维度是input_dim
,总共输入seq_len
个时间步,则每个时间步输入到这个整个RNN
模块的维度是[batch_size,input_dim]
# 构造RNN网络,x的维度5,隐层的维度10,网络的层数2
rnn_seq = nn.RNN(5, 10,2)
# 构造一个输入序列,句长为 6,batch 是 3, 每个单词使用长度是 5的向量表示
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0)
out,ht = rnn_seq(x) #h0可以指定或者不指定
问题1:这里out
、ht
的size是多少呢?
回答:out
:6 * 3 * 10, ht
: 2 * 3 * 10,out
的输出维度[seq_len,batch_size,output_dim]
,ht的维度[num_layers * num_directions, batch, hidden_size]
,如果是单向单层的RNN那么一个句子只有一个hidden
。
问题2:out[-1]
和ht[-1]
是否相等?
回答:相等,隐藏单元就是输出的最后一个单元,可以想象,每个的输出其实就是那个时间步的隐藏单元
RNN
的其他参数
RNN(input_dim ,hidden_dim ,num_layers ,…)
– input_dim 表示输入的特征维度
– hidden_dim 表示输出的特征维度,如果没有特殊变化,相当于out
– num_layers 表示网络的层数
– nonlinearity 表示选用的非线性激活函数,默认是 ‘tanh’
– bias 表示是否使用偏置,默认使用
– batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
– dropout 表示是否在输出层应用 dropout
– bidirectional 表示是否使用双向的 rnn,默认是 False
LSTM的输出多了一个memory单元
# 输入维度 50,隐层100维,两层
lstm_seq = nn.LSTM(50, 100, num_layers=2)
# 输入序列seq= 10,batch =3,输入维度=50
lstm_input = torch.randn(10, 3, 50)
out, (h, c) = lstm_seq(lstm_input) # 使用默认的全 0 隐藏状态
问题1:out
和(h,c)
的size各是多少?
回答:out
:(10 * 3 * 100),(h,c)
:都是(2 * 3 * 100)
问题2:out[-1,:,:]
和h[-1,:,:]
相等吗?
回答: 相等
GRU比较像传统的RNN
gru_seq = nn.GRU(10, 20,2) # x_dim,h_dim,layer_num
gru_input = torch.randn(3, 32, 10) # seq,batch,x_dim
out, h = gru_seq(gru_input)
pytorch, LSTM介绍的更多相关文章
- pytorch学习笔记(九):PyTorch结构介绍
PyTorch结构介绍对PyTorch架构的粗浅理解,不能保证完全正确,但是希望可以从更高层次上对PyTorch上有个整体把握.水平有限,如有错误,欢迎指错,谢谢! 几个重要的类型和数值相关的Tens ...
- 网络流量预测入门(二)之LSTM介绍
目录 网络流量预测入门(二)之LSTM介绍 LSTM简介 Simple RNN的弊端 LSTM的结构 细胞状态(Cell State) 门(Gate) 遗忘门(Forget Gate) 输入门(Inp ...
- LSTM介绍
转自:https://blog.csdn.net/gzj_1101/article/details/79376798 LSTM网络 long short term memory,即我们所称呼的LSTM ...
- RNN LSTM 介绍
[RNN以及LSTM的介绍和公式梳理]http://blog.csdn.net/Dark_Scope/article/details/47056361 [知乎 对比 rnn lstm 简单代码] ...
- pytorch lstm crf 代码理解 重点
好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...
- pytorch lstm crf 代码理解
好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...
- Pytorch LSTM 词性判断
首先,我们定义好一个LSTM网络,然后给出一个句子,每个句子都有很多个词构成,每个词可以用一个词向量表示,这样一句话就可以形成一个序列,我们将这个序列依次传入LSTM,然后就可以得到与序列等长的输出, ...
- pytorch LSTM情感分类全部代码
先运行main.py进行文本序列化,再train.py模型训练 dataset.py from torch.utils.data import DataLoader,Dataset import to ...
- RNN、LSTM介绍以及梯度消失问题讲解
写在最前面,感谢这两篇文章,基本上的框架是从这两篇文章中得到的: https://zhuanlan.zhihu.com/p/28687529 https://zhuanlan.zhihu.com/p/ ...
随机推荐
- GO语言系列(二)- 基本数据类型和操作符
一.文件名 & 关键字 & 标识符 1.所有go源码以.go结尾 2.标识符以字母或下划线开头,大小写敏感 3._是特殊标识符,用来忽略结果 4.保留关键字 二.Go程序的基本结构 p ...
- JGUI源码:Accordion折叠到侧边栏实现(6)
折叠和非折叠效果如左右图所示 代码如下 //折叠 $.fn.jAccordionfold = function() { return this.each(function() { var obj = ...
- 在 Visual Studio 中使用 IntelliTrace 快照功能
今天发现vs2017 IntelliTrace有了一个快照功能,测试一下它的用法 1.选项->IntelliTrace->选择第三项 2.建一个控制台应用程序 3.在main中写一个简单的 ...
- SQL Server安全
第一篇 SQL Server安全概述 第二篇 SQL Server安全验证 第三篇 SQL Server安全主体和安全对象 第四篇 SQL Server安全权限 第五篇 SQL Server安全架构和 ...
- 让iframe自适应高度-真正解决
需求:实现 iframe 的自适应高度,能够随着页面的长度自动的适应以免除页面和 iframe 同时出现滚动条的现象. (需要只有iframe出现滚动条) 本人一开始这么写:会造成只有主页面加载是设定 ...
- mysql和SQLYog工具使用
在做 android java web开发过程中,需要用到mysql数据库,使用可视化工具SQLYog管理工具,现介绍如下: 首先下载mysql工具 下载地址 https://dev.mysql. ...
- SQL Server - JOIN
JOIN
- 007_wireshark分析TCP的三次握手和四次断开
要想进行抓包分析,必须先了解TCP的原理.这里介绍了TCP的建立连接的三次握手和断开连接的四次握手. 一.前言:介绍三次握手之前,先介绍TCP层的几个FLAGS字段,这个字段有如下的几种标示 SYN表 ...
- .Net Core---- 通过EPPlus批量导出
前台代码: 前台代码是在.net core bootstrap集成框架上的(这是效果浏览地址:http://core.jucheap.com[效果地址来自:http://blog.csdn.net/a ...
- 【原创】大叔经验分享(19)spark on yarn提交任务之后执行进度总是10%
spark 2.1.1 系统中希望监控spark on yarn任务的执行进度,但是监控过程发现提交任务之后执行进度总是10%,直到执行成功或者失败,进度会突然变为100%,很神奇, 下面看spark ...