RNN基础:

『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练

TensorFlow RNN:

『TensotFlow』基础RNN网络分类问题

『TensotFlow』基础RNN网络回归问题

『TensotFlow』深层循环神经网络

『TensotFlow』LSTM古诗生成任务总结

对于torch中的RNN相关类,有原始和原始Cell之分,其中RNN和RNNCell层的区别在于前者一次能够处理整个序列,而后者一次只处理序列中一个时间点的数据,前者封装更完备更易于使用,后者更具灵活性。实际上RNN层的一种后端实现方式就是调用RNNCell来实现的。

一、nn.RNN

import torch as t
from torch import nn
from torch.autograd import Variable as V layer = 1 t.manual_seed(1000)
# 3句话,每句话2个字,每个字4维矢量
# batch为3,step为2,每个元素4维
input = V(t.randn(2,3,4))
# 1层,输出(隐藏)神经元3维,输入神经元4维
# 1层,3隐藏神经元,每个元素4维
lstm = nn.LSTM(4,3,layer)
# 初始状态:1层,batch为3,隐藏神经元3
h0 = V(t.randn(layer,3,3))
c0 = V(t.randn(layer,3,3)) out, hn = lstm(input,(h0,c0))
print(out, hn)
Variable containing:
(0 ,.,.) =
0.0545 -0.0061 0.5615
-0.1251 0.4490 0.2640
0.1405 -0.1624 0.0303 (1 ,.,.) =
0.0168 0.1562 0.5002
0.0824 0.1454 0.4007
0.0180 -0.0267 0.0094
[torch.FloatTensor of size 2x3x3]
(Variable containing:
(0 ,.,.) =
0.0168 0.1562 0.5002
0.0824 0.1454 0.4007
0.0180 -0.0267 0.0094
[torch.FloatTensor of size 1x3x3]
, Variable containing:
(0 ,.,.) =
0.1085 0.1957 0.9778
0.5397 0.2874 0.6415
0.0480 -0.0345 0.0141
[torch.FloatTensor of size 1x3x3]
)

二、nn.RNNCell

import torch as t
from torch import nn
from torch.autograd import Variable as V t.manual_seed(1000)
# batch为3,step为2,每个元素4维
input = V(t.randn(2,3,4))
# Cell只能是1层,3隐藏神经元,每个元素4维
lstm = nn.LSTMCell(4,3)
# 初始状态:1层,batch为3,隐藏神经元3
hx = V(t.randn(3,3))
cx = V(t.randn(3,3)) out = [] # 每个step提取各个batch的四个维度
for i_ in input:
print(i_.shape)
hx, cx = lstm(i_,(hx,cx))
out.append(hx)
t.stack(out)
torch.Size([3, 4])
torch.Size([3, 4])
Variable containing:
(0 ,.,.) =
0.0545 -0.0061 0.5615
-0.1251 0.4490 0.2640
0.1405 -0.1624 0.0303 (1 ,.,.) =
0.0168 0.1562 0.5002
0.0824 0.1454 0.4007
0.0180 -0.0267 0.0094
[torch.FloatTensor of size 2x3x3]

三、nn.Embedding

embedding将标量表示的字符(所以是LongTensor)转换成矢量,这里给出一个模拟:将标量词embedding后送入rnn转换一下维度。

import torch as t
from torch import nn
from torch.autograd import Variable as V # 5个词,每个词使用4维向量表示
embedding = nn.Embedding(5, 4)
# 使用预训练好的词向量初始化
embedding.weight.data = t.arange(0, 20).view(5, 4) # 大小对应nn.Embedding(5, 4) # embedding将标量表示的字符(所以是LongTensor)转换成矢量
# 实际输入词原始向量需要是LongTensor格式
input = V(t.arange(3, 0, -1)).long()
# 1个batch,3个step,4维矢量
input = embedding(input).unsqueeze(1)
print("embedding后:",input.size()) # 1层,3隐藏神经元(输出元素4维度),每个元素4维
layer = 1
lstm = nn.LSTM(4, 3, layer)
# 初始状态:1层,batch为3,隐藏神经元3
h0 = V(t.randn(layer, 3, 3))
c0 = V(t.randn(layer, 3, 3))
out, hn = lstm(input, (h0, c0))
print("LSTM输出:",out.size())
embedding后: torch.Size([3, 1, 4])
LSTM输出: torch.Size([3, 3, 3])

『PyTorch』第十弹_循环神经网络的更多相关文章

  1. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  2. 『MXNet』第十弹_物体检测SSD

    全流程地址 一.辅助API介绍 mxnet.image.ImageDetIter 图像检测迭代器, from mxnet import image from mxnet import nd data_ ...

  3. 『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor

    Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arang ...

  4. 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法

    在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...

  5. 『TensorFlow』第十弹_队列&多线程_道路多坎坷

    一.基本队列: 队列有两个基本操作,对应在tf中就是enqueue&dequeue tf.FIFOQueue(2,'int32') import tensorflow as tf '''FIF ...

  6. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...

  7. 『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数

    一.封装新的PyTorch函数 继承Function类 forward:输入Variable->中间计算Tensor->输出Variable backward:均使用Variable 线性 ...

  8. 『PyTorch』第五弹_深入理解autograd_中:Variable梯度探究

    查看非叶节点梯度的两种方法 在反向传播过程中非叶子节点的导数计算完之后即被清空.若想查看这些变量的梯度,有两种方法: 使用autograd.grad函数 使用hook autograd.grad和ho ...

  9. 『PyTorch』第五弹_深入理解Tensor对象_中下:数学计算以及numpy比较_&_广播原理简介

    一.简单数学操作 1.逐元素操作 t.clamp(a,min=2,max=4)近似于tf.clip_by_value(A, min, max),修剪值域. a = t.arange(0,6).view ...

随机推荐

  1. Linux基础命令---mkfs

    mkfs 在磁盘分区上创建ext2.ext3.ext4.ms-dos.vfat文件系统,默认情况下会创建ext2.mkfs用于在设备上构建Linux文件系统,通常是硬盘分区.文件要么是设备名称(例如/ ...

  2. 关于hibernate中的session与数据库连接关系以及getCurrentSession 与 openSession() 的区别

    1.session与connection,是多对一关系,每个session都有一个与之对应的connection,一个connection不同时刻可以供多个session使用.   2.多个sessi ...

  3. Linux中Postfix邮件安装Maildrop(八)

    Postfix使用maildrop投递邮件 Maildrop是本地邮件投递代理(MDA), 支持过滤(/etc/maildroprc).投递和磁盘限额(Quota)功能. Maildrop是一个使用C ...

  4. VMware前路难测,多个厂家群雄逐鹿

    以VMware为例,虚拟机巨头公布了第二财季报告所示,它第二财季收入同比增长13%,达到了21.7亿美元,而且该公司收入和每股收益均超出预期. 在人们高谈Salesforce.亚马逊等新兴云计算厂商取 ...

  5. 查询oracle数据库里面所有的表名

    如果是当前用户,"select * from tab"即可

  6. 微信小程序新闻列表功能(读取文件、template模板使用)

    微信小程序新闻列表功能(读取文件.template) 不忘初心,方得始终.初心易得,始终难守. 在之前的项目基础上进行修改,实现读取文件内容作为新闻内容进行展示. 首先,修改 post.wxml 文件 ...

  7. 2019“嘉韦思”杯RSA256题目wp

    首先我们从网站下载了一个压缩包,解压出来一看里面有2个文件 首先我们先打开fllllllag康康,结果发现是一串乱码,这时候第一反应就是,文件被加密了,再看fllllllag下面的gy.key文件,更 ...

  8. ubuntu16.04下内核模块解析

    一.环境如下: 1.1内核版本: jello@jello:~$ uname -a Linux jello 4.4.0-89-generic #112-Ubuntu SMP Mon Jul 31 19: ...

  9. P3386 【模板】二分图匹配 -网络流版

    二分图匹配 题目背景 二分图 感谢@一扶苏一 提供的hack数据 题目描述 给定一个二分图,结点个数分别为n,m,边数为e,求二分图最大匹配数 输入输出格式 输入格式: 第一行,n,m,e 第二至e+ ...

  10. MVC5 一套Action的登录控制流程

    流程: 用拦截器控制每一个页面请求和ajax请求,根据请求体的cookie里面是否有token判断是否登录,还必须判断该token在redis里面的缓存是否存在 组成部分: 拦截器: using Sy ...