import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt # 超参数
# Hyper Parameters
# 训练整批数据多少次, 为了节约时间, 只训练一次
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 64
TIME_STEP = 28 # rnn time step / image height 时间步数 / 图片高度
INPUT_SIZE = 28 # rnn input size / image width 每步输入值 / 图片每行像素
LR = 0.01 # learning rate
DOWNLOAD_MNIST = True # set to True if haven't download the data # Mnist 手写数字
# Mnist digital dataset
train_data = dsets.MNIST(
root='./mnist/', # 保存或者提取位置
train=True, # this is training data
transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
download=DOWNLOAD_MNIST, # download it if you don't have it
) # plot one example
print(train_data.train_data.size()) # (60000, 28, 28)
print(train_data.train_labels.size()) # (60000)
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show() # 数据加载
# Data Loader for easy mini-batch return in training 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试数据
# convert test data into Variable, pick 2000 samples to speed up testing
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)
test_y = test_data.test_labels.numpy()[:2000] # covert to numpy array # 用一个 class 来建立 RNN 模型.
# 这个 RNN 整体流程:
# (input0, state0) -> LSTM -> (output0, state1);
# (input1, state1) -> LSTM -> (output1, state2);
# …
# (inputN, stateN)-> LSTM -> (outputN, stateN+1);
# outputN -> Linear -> prediction.
# 通过LSTM分析每一时刻的值, 并且将这一时刻和前面时刻的理解合并在一起, 生成当前时刻对前面数据的理解或记忆. 传递这种理解给下一时刻分析.
# 定义神经网络
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__() self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns LSTM 效果要比 nn.RNN() 好多了
input_size=INPUT_SIZE,
hidden_size=64, # rnn hidden unit
num_layers=1, # number of rnn layer 有几层 RNN layers
batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size) input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
)
# 输出层
self.out = nn.Linear(64, 10) def forward(self, x):
# x shape (batch, time_step, input_size)
# r_out shape (batch, time_step, output_size)
# h_n shape (n_layers, batch, hidden_size) # LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
# h_c shape (n_layers, batch, hidden_size)
r_out, (h_n, h_c) = self.rnn(x, None) # None represents zero initial hidden state # 选取最后一个时间点的 r_out 输出
# choose r_out at the last time step
out = self.out(r_out[:, -1, :]) # 这里 r_out[:, -1, :] 的值也是 h_n 的值
return out rnn = RNN()
print(rnn)
# 选择优化器
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters
# 选择损失函数
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted # 训练和测试
# 将图片数据看成一个时间上的连续数据, 每一行的像素点都是这个时刻的输入,
# 读完整张图片就是从上而下的读完了每行的像素点. 然后我们就可以拿出 RNN 在最后一步的分析值判断图片是哪一类了.
# training and testing
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader): # gives batch data
b_x = b_x.view(-1, 28, 28) # reshape x to (batch, time_step, input_size) output = rnn(b_x) # rnn output # 喂给 rnn net 训练数据 b_x, 输出预测值
loss = loss_func(output, b_y) # cross entropy loss # 计算两者的误差
optimizer.zero_grad() # clear gradients for this training step # 清空上一步的残余更新参数值
loss.backward() # backpropagation, compute gradients # 误差反向传播, 计算参数更新值
optimizer.step() # apply gradients # 将参数更新值施加到 rnn net 的 parameters 上 if step % 50 == 0:
test_output = rnn(test_x) # (samples, time_step, input_size)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy) # print 10 predictions from test data
test_output = rnn(test_x[:10].view(-1, 28, 28))
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

pytorch1.0实现RNN-LSTM for Classification的更多相关文章

  1. pytorch1.0实现RNN for Regression

    import torch from torch import nn import numpy as np import matplotlib.pyplot as plt # 超参数 # Hyper P ...

  2. RNN/LSTM/GRU/seq2seq公式推导

    概括:RNN 适用于处理序列数据用于预测,但却受到短时记忆的制约.LSTM 和 GRU 采用门结构来克服短时记忆的影响.门结构可以调节流经序列链的信息流.LSTM 和 GRU 被广泛地应用到语音识别. ...

  3. 时间序列(六): 炙手可热的RNN: LSTM

    目录 炙手可热的LSTM 引言 RNN的问题 恐怖的指数函数 梯度消失* 解决方案 LSTM 设计初衷 LSTM原理 门限控制* LSTM 的 BPTT 参考文献: 炙手可热的LSTM 引言 上一讲说 ...

  4. [NL系列] RNN & LSTM 网络结构及应用

    http://www.jianshu.com/p/f3bde26febed/ 这篇是 The Unreasonable Effectiveness of Recurrent Neural Networ ...

  5. RNN,LSTM,GRU基本原理的个人理解

    记录一下对RNN,LSTM,GRU基本原理(正向过程以及简单的反向过程)的个人理解 RNN Recurrent Neural Networks,循环神经网络 (注意区别于recursive neura ...

  6. 用pytorch1.0快速搭建简单的神经网络

    用pytorch1.0搭建简单的神经网络 import torch import torch.nn.functional as F # 包含激励函数 # 建立神经网络 # 先定义所有的层属性(__in ...

  7. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  8. 深度学习中的序列模型演变及学习笔记(含RNN/LSTM/GRU/Seq2Seq/Attention机制)

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...

  9. pytorch1.0 用torch script导出模型

    python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发.低延迟的要求,存在需要用c++接口的情况. ...

随机推荐

  1. C# 百度API地址坐标互相转换

    通过C#代码将地址字符串转为经纬度坐标,或者将经纬度转为具体的地址字符串,在不通外网的项目中是有需求的. 具体步骤: 一.创建BaiduMapHelper,用于定义地址信息和请求. public st ...

  2. Android 照片上传

    解释全在代码中: // 拍照上传 private OnClickListener mUploadClickListener = new OnClickListener() { public void ...

  3. 用vue做的购物车结算的功能

    <!-- 占位 --> <template> <div> <div class="product_table"> <div c ...

  4. FLUENT求解传热系数surfaceheattransfercoef.的参考值的设置【转载】

    转载自:http://blog.sina.com.cn/s/blog_7ef78d170101ch30.html surface heat transfer coef. 计算公式: FLUENT求解传 ...

  5. JAVA基础之访问控制权限

    包:库单元 1.当编写一个Java源代码文件时,此文件通常被称为编译单元(有时也被称为转译单元). 2.每个编译单元都必须有一个后缀名.java,而在编译单元内则可以有一个public类,该类名称必须 ...

  6. JVM 类加载器命名空间深度解析与实例分析

    一.创建Sample 1.创建实例 public class MyPerson { private MyPerson myPerson; public void setMyPerson(Object ...

  7. web-linux-shell实现 阿里方案canvas+wss。

    wss://shell.aliyun.com/terminals?cols=92&rows=35 [root@webshell ~]# python Python 3.6.8 (default ...

  8. 对pdf中的图片进行自动识别

    对pdf中的图片进行自动识别 商务合作,科技咨询,版权转让:向日葵,135—4855__4328,xiexiaokui#qq.com 原理:增强扫描 效果:自动识别所有图片中的文字,可以选择.复制,进 ...

  9. ISO/IEC 9899:2011 条款6.7.2——类型说明符

    6.7.2 类型说明符 语法 1.type-specifier: void char short int long float double signed unsigned _Bool _Comple ...

  10. k8s记录-kubeadm安装(一)(转载)

    配置 kubeadm 概述 安装 kubernetes 主要是安装它的各个镜像,而 kubeadm 已经为我们集成好了运行 kubernetes 所需的基本镜像.但由于国内的网络原因,在搭建环境时,无 ...