本文中的RNN泛指LSTM,GRU等等
CNN中和RNNbatchSize的默认位置是不同的。

  • CNN中:batchsize的位置是position 0.
  • RNN中:batchsize的位置是position 1.

在RNN中输入数据格式:

对于最简单的RNN,我们可以使用两种方式来调用,torch.nn.RNNCell(),它只接受序列中的单步输入,必须显式的传入隐藏状态torch.nn.RNN()可以接受一个序列的输入,默认会传入一个全0的隐藏状态,也可以自己申明隐藏状态传入。

  1. 输入大小是三维tensor[seq_len,batch_size,input_dim]
  • input_dim是输入的维度,比如是128
  • batch_size是一次往RNN输入句子的数目,比如是5
  • seq_len是一个句子的最大长度,比如15
    所以千万注意,RNN输入的是序列,一次把批次的所有句子都输入了,得到的ouptuthidden都是这个批次的所有的输出和隐藏状态,维度也是三维。
    **可以理解为现在一共有batch_size个独立的RNN组件,RNN的输入维度是input_dim,总共输入seq_len个时间步,则每个时间步输入到这个整个RNN模块的维度是[batch_size,input_dim]
# 构造RNN网络,x的维度5,隐层的维度10,网络的层数2
rnn_seq = nn.RNN(5, 10,2)
# 构造一个输入序列,句长为 6,batch 是 3, 每个单词使用长度是 5的向量表示
x = torch.randn(6, 3, 5)
#out,ht = rnn_seq(x,h0)
out,ht = rnn_seq(x) #h0可以指定或者不指定

问题1:这里outhtsize是多少呢?
回答out:6 * 3 * 10, ht: 2 * 3 * 10,out的输出维度[seq_len,batch_size,output_dim],ht的维度[num_layers * num_directions, batch, hidden_size],如果是单向单层的RNN那么一个句子只有一个hidden
问题2out[-1]ht[-1]是否相等?
回答:相等,隐藏单元就是输出的最后一个单元,可以想象,每个的输出其实就是那个时间步的隐藏单元

  1. RNN的其他参数
RNN(input_dim ,hidden_dim ,num_layers ,…)
– input_dim 表示输入的特征维度
– hidden_dim 表示输出的特征维度,如果没有特殊变化,相当于out
– num_layers 表示网络的层数
– nonlinearity 表示选用的非线性激活函数,默认是 ‘tanh’
– bias 表示是否使用偏置,默认使用
– batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
– dropout 表示是否在输出层应用 dropout
– bidirectional 表示是否使用双向的 rnn,默认是 False
 
向RNN中输入的tensor的形状

LSTM的输出多了一个memory单元

# 输入维度 50,隐层100维,两层
lstm_seq = nn.LSTM(50, 100, num_layers=2)
# 输入序列seq= 10,batch =3,输入维度=50
lstm_input = torch.randn(10, 3, 50)
out, (h, c) = lstm_seq(lstm_input) # 使用默认的全 0 隐藏状态

问题1out(h,c)的size各是多少?
回答out:(10 * 3 * 100),(h,c):都是(2 * 3 * 100)
问题2out[-1,:,:]h[-1,:,:]相等吗?
回答: 相等

GRU比较像传统的RNN

gru_seq = nn.GRU(10, 20,2) # x_dim,h_dim,layer_num
gru_input = torch.randn(3, 32, 10) # seq,batch,x_dim
out, h = gru_seq(gru_input)

作者:VanJordan
链接:https://www.jianshu.com/p/b942e65cb0a3
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

[PyTorch] rnn,lstm,gru中输入输出维度的更多相关文章

  1. 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...

  2. RNN,LSTM,GRU基本原理的个人理解

    记录一下对RNN,LSTM,GRU基本原理(正向过程以及简单的反向过程)的个人理解 RNN Recurrent Neural Networks,循环神经网络 (注意区别于recursive neura ...

  3. RNN/LSTM/GRU/seq2seq公式推导

    概括:RNN 适用于处理序列数据用于预测,但却受到短时记忆的制约.LSTM 和 GRU 采用门结构来克服短时记忆的影响.门结构可以调节流经序列链的信息流.LSTM 和 GRU 被广泛地应用到语音识别. ...

  4. RNN - LSTM - GRU

    循环神经网络 (Recurrent Neural Network,RNN) 是一类具有短期记忆能力的神经网络,因而常用于序列建模.本篇先总结 RNN 的基本概念,以及其训练中时常遇到梯度爆炸和梯度消失 ...

  5. RNN & LSTM & GRU 的原理与区别

      RNN 循环神经网络,是非线性动态系统,将序列映射到序列,主要参数有五个:[Whv,Whh,Woh,bh,bo,h0][Whv,Whh,Woh,bh,bo,h0],典型的结构图如下: 和普通神经网 ...

  6. RNN, LSTM, GRU cells

    项目需要,先简记cell,有时间再写具体改进原因 RNN cell LSTM cell: GRU cell: reference: 1.https://towardsdatascience.com/a ...

  7. 【pytorch学习笔记0】-CNN与LSTM输入输出维度含义

    卷积data的四个维度: batch, input channel, height, width Conv2d的四个维度: input channel, output channel, kernel, ...

  8. NLP教程(5) - 语言模型、RNN、GRU与LSTM

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/36 本文地址:http://www.showmeai.tech/article-det ...

  9. RNN,GRU,LSTM

    2019-08-29 17:17:15 问题描述:比较RNN,GRU,LSTM. 问题求解: 循环神经网络 RNN 传统的RNN是维护了一个隐变量 ht 用来保存序列信息,ht 基于 xt 和 ht- ...

随机推荐

  1. linux 忘记root(这里以centos 6.5为例)密码的解决办法

    在使用linux的过程中有时候会忘记root用户的密码(尤其是进行交接而文档内容不全的时候),这个时候我们就可以进入单用户模式来重置root用户密码.下面来讲解重置root密码的方式,也可以说是破解r ...

  2. snnu1111(子序列求和)

    1111: 子序列求和 Time Limit: 3 Sec  Memory Limit: 64 MBSubmit: 10  Solved: 2[Submit][Status][Web Board] [ ...

  3. 基于dubbo2.5.5+zookeeper3.4.9的服务搭建

    参考资料:https://segmentfault.com/a/1190000009568509https://segmentfault.com/a/1190000004654903 0. 环境 Ja ...

  4. hdu4604 Deque(最长上升子序列变形)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4604 题意:一个含有n个数栈,每次取出一个数,可以把这个数放在deque(双向队列)首部,放在尾部,或 ...

  5. fiddler工具使用

    一.安装 a) 百度fiddler ,下载, 安装 ,无脑流 二.界面介绍 a) 工具栏与状态栏 其中保存是,可以为两种形式:1.text文本形式 2.saz文件结尾数据(能被fiddler软件识别) ...

  6. FOJ2252:Yu-Gi-Oh!

    传送门 题意 略 分析 枚举合成ab的数量,在此基础上合成bc和ac,复杂度\(O(n)\) trick 代码 #include<cstdio> #include<algorithm ...

  7. 洛谷 P3254 圆桌问题【最大流】

    s向所有单位连流量为人数的边,所有饭桌向t连流量为饭桌容量的边,每个单位向每个饭桌连容量为1的边表示这个饭桌只能坐这个单位的一个人.跑dinic如果小于总人数则无解,否则对于每个单位for与它相连.满 ...

  8. jsp标准标签库——jstl

    JSP标准标签库(JSTL)是一个JSP标签集合,它封装了JSP应用的通用核心功能. JSTL支持通用的.结构化的任务,比如迭代,条件判断,XML文档操作,国际化标签,SQL标签. 除了这些,它还提供 ...

  9. DP(两次) UVA 10163 Storage Keepers

    题目传送门 /* 题意:(我懒得写,照搬网上的)有n个仓库,m个人看管.一个仓库只能由一个人来看管,一个人可以看管多个仓库. 每个人有一个能力值pi,如果他看管k个仓库,那么所看管的每个仓库的安全值为 ...

  10. synchronized(1)用法简介:修饰方法,修饰语句块

    注意: 同一个对象或方法在不同线程中才出现同步问题,不同对象在不同线程互相不干扰. synchronized方法有2种用法:修饰方法,修饰语句块 1.synchronized方法 是某个对象实例内,s ...