import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt # 超参数
# Hyper Parameters
# 训练整批数据多少次, 为了节约时间, 只训练一次
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 64
TIME_STEP = 28 # rnn time step / image height 时间步数 / 图片高度
INPUT_SIZE = 28 # rnn input size / image width 每步输入值 / 图片每行像素
LR = 0.01 # learning rate
DOWNLOAD_MNIST = True # set to True if haven't download the data # Mnist 手写数字
# Mnist digital dataset
train_data = dsets.MNIST(
root='./mnist/', # 保存或者提取位置
train=True, # this is training data
transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
download=DOWNLOAD_MNIST, # download it if you don't have it
) # plot one example
print(train_data.train_data.size()) # (60000, 28, 28)
print(train_data.train_labels.size()) # (60000)
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show() # 数据加载
# Data Loader for easy mini-batch return in training 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试数据
# convert test data into Variable, pick 2000 samples to speed up testing
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)
test_y = test_data.test_labels.numpy()[:2000] # covert to numpy array # 用一个 class 来建立 RNN 模型.
# 这个 RNN 整体流程:
# (input0, state0) -> LSTM -> (output0, state1);
# (input1, state1) -> LSTM -> (output1, state2);
# …
# (inputN, stateN)-> LSTM -> (outputN, stateN+1);
# outputN -> Linear -> prediction.
# 通过LSTM分析每一时刻的值, 并且将这一时刻和前面时刻的理解合并在一起, 生成当前时刻对前面数据的理解或记忆. 传递这种理解给下一时刻分析.
# 定义神经网络
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__() self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns LSTM 效果要比 nn.RNN() 好多了
input_size=INPUT_SIZE,
hidden_size=64, # rnn hidden unit
num_layers=1, # number of rnn layer 有几层 RNN layers
batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size) input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
)
# 输出层
self.out = nn.Linear(64, 10) def forward(self, x):
# x shape (batch, time_step, input_size)
# r_out shape (batch, time_step, output_size)
# h_n shape (n_layers, batch, hidden_size) # LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
# h_c shape (n_layers, batch, hidden_size)
r_out, (h_n, h_c) = self.rnn(x, None) # None represents zero initial hidden state # 选取最后一个时间点的 r_out 输出
# choose r_out at the last time step
out = self.out(r_out[:, -1, :]) # 这里 r_out[:, -1, :] 的值也是 h_n 的值
return out rnn = RNN()
print(rnn)
# 选择优化器
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters
# 选择损失函数
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted # 训练和测试
# 将图片数据看成一个时间上的连续数据, 每一行的像素点都是这个时刻的输入,
# 读完整张图片就是从上而下的读完了每行的像素点. 然后我们就可以拿出 RNN 在最后一步的分析值判断图片是哪一类了.
# training and testing
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader): # gives batch data
b_x = b_x.view(-1, 28, 28) # reshape x to (batch, time_step, input_size) output = rnn(b_x) # rnn output # 喂给 rnn net 训练数据 b_x, 输出预测值
loss = loss_func(output, b_y) # cross entropy loss # 计算两者的误差
optimizer.zero_grad() # clear gradients for this training step # 清空上一步的残余更新参数值
loss.backward() # backpropagation, compute gradients # 误差反向传播, 计算参数更新值
optimizer.step() # apply gradients # 将参数更新值施加到 rnn net 的 parameters 上 if step % 50 == 0:
test_output = rnn(test_x) # (samples, time_step, input_size)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy) # print 10 predictions from test data
test_output = rnn(test_x[:10].view(-1, 28, 28))
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

pytorch1.0实现RNN-LSTM for Classification的更多相关文章

  1. pytorch1.0实现RNN for Regression

    import torch from torch import nn import numpy as np import matplotlib.pyplot as plt # 超参数 # Hyper P ...

  2. RNN/LSTM/GRU/seq2seq公式推导

    概括:RNN 适用于处理序列数据用于预测,但却受到短时记忆的制约.LSTM 和 GRU 采用门结构来克服短时记忆的影响.门结构可以调节流经序列链的信息流.LSTM 和 GRU 被广泛地应用到语音识别. ...

  3. 时间序列(六): 炙手可热的RNN: LSTM

    目录 炙手可热的LSTM 引言 RNN的问题 恐怖的指数函数 梯度消失* 解决方案 LSTM 设计初衷 LSTM原理 门限控制* LSTM 的 BPTT 参考文献: 炙手可热的LSTM 引言 上一讲说 ...

  4. [NL系列] RNN & LSTM 网络结构及应用

    http://www.jianshu.com/p/f3bde26febed/ 这篇是 The Unreasonable Effectiveness of Recurrent Neural Networ ...

  5. RNN,LSTM,GRU基本原理的个人理解

    记录一下对RNN,LSTM,GRU基本原理(正向过程以及简单的反向过程)的个人理解 RNN Recurrent Neural Networks,循环神经网络 (注意区别于recursive neura ...

  6. 用pytorch1.0快速搭建简单的神经网络

    用pytorch1.0搭建简单的神经网络 import torch import torch.nn.functional as F # 包含激励函数 # 建立神经网络 # 先定义所有的层属性(__in ...

  7. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  8. 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...

  9. pytorch1.0 用torch script导出模型

    python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...

随机推荐

  1. 4、http协议之二

    URL(Unifrom Resource Locator)简述 相对URL 从当前页面(同一个站点内或同一个文章内引用) 绝对URL 从当前页面或其他页面跳转而来(跨站引用) HTTPD版本<0 ...

  2. ZR#1004

    ZR#1004 解法: 对于 $ (x^2 + y)^2 \equiv (x^2 - y)^2 + 1 \pmod p $ 化简并整理得 $ 4x^2y \equiv 1 \pmod p $ 即 $ ...

  3. 怎么把分化成元,并且保留两位小数,用vue来做

    <el-table-column prop="amount" label="申请提现金额" width="120" align=&qu ...

  4. GO语言网络编程

    socket编程 Socket是BSD UNIX的进程通信机制,通常也称作"套接字",用于描述IP地址和端口,是一个通信链的句柄.Socket可以理解为TCP/IP网络的API,它 ...

  5. Java中对象并不是都在堆上分配内存的

    转(https://blog.51cto.com/13906751/2153924) 前段时间,给星球的球友们专门码了一篇文章<深入分析Java的编译原理>,其中深入的介绍了Java中的j ...

  6. 康哲20191114-1 每周例行报告kz404

    此作业的要求参见https://edu.cnblogs.com/campus/nenu/2019fall/homework/10004 本周PSP  本周进度条  本周折线图  饼状图

  7. elasticsearch使用BulkProcessor批量入库数据

    在解决es入库问题上,之前使用过rest方式,经过一段时间的测试发现千万级别的数据会存在10至上百条数据的丢失问题, 在需要保证数据的准确性的场景下,rest方式并不能保证结果的准确性,因此采用了el ...

  8. Eclipse 的快捷键以及文档注释、多行注释的快捷键 一、多行注释快捷键

    一.多行注释快捷键 1.选中你要加注释的区域,用ctrl+shift+C 或者ctrl+/ 会加上//注释2.先把你要注释的东西选中,用shit+ctrl+/ 会加上/*    */注释 3.以上快捷 ...

  9. markdown2的key

    分享一个MarkDown2的授权key   邮箱地址: Soar360@live.com 授权秘钥: GBPduHjWfJU1mZqcPM3BikjYKF6xKhlKIys3i1MU2eJHqWGIm ...

  10. ISO/IEC 9899:2011 条款6.7.6——声明符

    6.7.6 声明符 语法 1.declarator: pointeropt    direct-declarator direct-declarator: identifier (    declar ...