Pytorch是热门的深度学习框架之一,通过经典的MNIST 数据集进行快速的pytorch入门。

导入库

  1. from torchvision.datasets import MNIST
  2. from torchvision.transforms import ToTensor, Compose, Normalize
  3. from torch.utils.data import DataLoader
  4. import torch
  5. import torch.nn.functional as F
  6. import torch.nn as nn
  7. import os
  8. import numpy as np

准备数据集

  1. path = './data'
  2. # 使用Compose 将tensor化和正则化操作打包
  3. transform_fn = Compose([
  4. ToTensor(),
  5. Normalize(mean=(0.1307,), std=(0.3081,))
  6. ])
  7. mnist_dataset = MNIST(root=path, train=True, transform=transform_fn)
  1. data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=2, shuffle=True)
  1. # 1. 构建函数,数据集预处理
  2. BATCH_SIZE = 128
  3. TEST_BATCH_SIZE = 1000
  4. def get_dataloader(train=True, batch_size=BATCH_SIZE):
  5. '''
  6. train=True, 获取训练集
  7. train=False 获取测试集
  8. '''
  9. transform_fn = Compose([
  10. ToTensor(),
  11. Normalize(mean=(0.1307,), std=(0.3081,))
  12. ])
  13. dataset = MNIST(root='./data', train=train, transform=transform_fn)
  14. data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
  15. return data_loader

构建模型


  1. class MnistModel(nn.Module):
  2. def __init__(self):
  3. super().__init__() # 继承父类
  4. self.fc1 = nn.Linear(1*28*28, 28) # 添加全连接层
  5. self.fc2 = nn.Linear(28, 10)
  6. def forward(self, input):
  7. x = input.view(-1, 1*28*28)
  8. x = self.fc1(x)
  9. x = F.relu(x)
  10. out = self.fc2(x)
  11. return F.log_softmax(out, dim=-1) # log_softmax 与 nll_loss合用,计算交叉熵

模型训练

  1. mnist_model = MnistModel()
  2. optimizer = torch.optim.Adam(params=mnist_model.parameters(), lr=0.001)
  3. # 如果有模型则加载
  4. if os.path.exists('./model'):
  5. mnist_model.load_state_dict(torch.load('model/mnist_model.pkl'))
  6. optimizer.load_state_dict(torch.load('model/optimizer.pkl'))
  1. def train(epoch):
  2. data_loader = get_dataloader()
  3. for index, (data, target) in enumerate(data_loader):
  4. optimizer.zero_grad() # 梯度先清零
  5. output = mnist_model(data)
  6. loss = F.nll_loss(output, target)
  7. loss.backward() # 误差反向传播计算
  8. optimizer.step() # 更新梯度
  9. if index % 100 == 0:
  10. # 保存训练模型
  11. torch.save(mnist_model.state_dict(), 'model/mnist_model.pkl')
  12. torch.save(optimizer.state_dict(), 'model/optimizer.pkl')
  13. print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
  14. epoch, index * len(data), len(data_loader.dataset),
  15. 100. * index / len(data_loader), loss.item()))
  1. for i in range(epoch=5):
  2. train(i)
  1. Train Epoch: 0 [0/60000 (0%)] Loss: 0.023078
  2. Train Epoch: 0 [12800/60000 (21%)] Loss: 0.019347
  3. Train Epoch: 0 [25600/60000 (43%)] Loss: 0.105870
  4. Train Epoch: 0 [38400/60000 (64%)] Loss: 0.050866
  5. Train Epoch: 0 [51200/60000 (85%)] Loss: 0.097995
  6. Train Epoch: 1 [0/60000 (0%)] Loss: 0.108337
  7. Train Epoch: 1 [12800/60000 (21%)] Loss: 0.071196
  8. Train Epoch: 1 [25600/60000 (43%)] Loss: 0.022856
  9. Train Epoch: 1 [38400/60000 (64%)] Loss: 0.028392
  10. Train Epoch: 1 [51200/60000 (85%)] Loss: 0.070508
  11. Train Epoch: 2 [0/60000 (0%)] Loss: 0.037416
  12. Train Epoch: 2 [12800/60000 (21%)] Loss: 0.075977
  13. Train Epoch: 2 [25600/60000 (43%)] Loss: 0.024356
  14. Train Epoch: 2 [38400/60000 (64%)] Loss: 0.042203
  15. Train Epoch: 2 [51200/60000 (85%)] Loss: 0.020883
  16. Train Epoch: 3 [0/60000 (0%)] Loss: 0.023487
  17. Train Epoch: 3 [12800/60000 (21%)] Loss: 0.024403
  18. Train Epoch: 3 [25600/60000 (43%)] Loss: 0.073619
  19. Train Epoch: 3 [38400/60000 (64%)] Loss: 0.074042
  20. Train Epoch: 3 [51200/60000 (85%)] Loss: 0.036283
  21. Train Epoch: 4 [0/60000 (0%)] Loss: 0.021305
  22. Train Epoch: 4 [12800/60000 (21%)] Loss: 0.062750
  23. Train Epoch: 4 [25600/60000 (43%)] Loss: 0.016911
  24. Train Epoch: 4 [38400/60000 (64%)] Loss: 0.039599
  25. Train Epoch: 4 [51200/60000 (85%)] Loss: 0.026689

模型测试

  1. def test():
  2. loss_list = []
  3. acc_list = []
  4. test_loader = get_dataloader(train=False, batch_size = TEST_BATCH_SIZE)
  5. mnist_model.eval() # 设为评估模式
  6. for index, (data, target) in enumerate(test_loader):
  7. with torch.no_grad():
  8. out = mnist_model(data)
  9. loss = F.nll_loss(out, target)
  10. loss_list.append(loss)
  11. pred = out.data.max(1)[1]
  12. acc = pred.eq(target).float().mean() # eq()函数用于将两个tensor中的元素对比,返回布尔值
  13. acc_list.append(acc)
  14. print('平均准确率, 平均损失', np.mean(acc_list), np.mean(loss_list))
  1. test()
  1. 平均准确率, 平均损失 0.9662777 0.12309619

Pytorch实现MNIST手写数字识别的更多相关文章

  1. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  2. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  3. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  4. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  5. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  6. 第三节,CNN案例-mnist手写数字识别

    卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...

  7. mnist 手写数字识别

    mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...

  8. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  9. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

随机推荐

  1. 终极解决方案——sbt配置阿里镜像源,解决sbt下载慢,dump project structure from sbt耗时问题

    #sbt下载慢的问题 默认情况下,sbt使用mvn2仓库下载依赖,如下载scalatest时,idea的sbtshell 显示如下url https://repo1.maven.org/maven2/ ...

  2. 面试刷题25:jvm的垃圾收集算法?

    垃圾收集是java语言的亮点,大大提高了开发人员的效率. 垃圾收集即GC,当内存不足的时候触发,不同的jvm版本算法和机制都有差别. 我是李福春,我在准备面试,今天的问题是: jvm的垃圾回收算法有哪 ...

  3. 说说自己为什么用Mac不用Win系统?

    原本Mac和Win系统各有优劣,但偏偏最近有人误导身边的朋友说"学编程肯定是Windows系统呀,Mac不行的",又不给出有说服力的理由,于是我心有愤懑,正好趁机总结一下自己对于两 ...

  4. GB2312,GBK和UTF-8的区别

    GBK GBK包含全部中文字符, GBK的文字编码是双字节来表示的,即不论中.英文字符均使用双字节来表示,只不过为区分中文,将其最高位都定成1.至于UTF-8编码则是用以解决国际上字符的一种多字节编码 ...

  5. PyTorch Hub发布!一行代码调用最潮模型,图灵奖得主强推

    为了调用各种经典机器学习模型,今后你不必重复造轮子了. 刚刚,Facebook宣布推出PyTorch Hub,一个包含计算机视觉.自然语言处理领域的诸多经典模型的聚合中心,让你调用起来更方便. 有多方 ...

  6. 一 JVM垃圾回收模型

    一 JVM垃圾回收模型 一. GC算法 1.1 标记-清除算法(Mark-Sweep) 算法分为"标记"和"清除"两个阶段首先标记出所有需要回收的对象,然后回收 ...

  7. 【Pytest03】全网最全最新的Pytest框架fixture应用篇(1)

    fixtrue修饰器标记的方法通常用于在其他函数.模块.类或者整个工程调用时会优先执行,通常会被用于完成预置处理和重复操作.例如:登录,执行SQL等操作. 完整方法如下:fixture(scope=' ...

  8. SpringBoot,SpringMvc, SpringCloud

    1,SpringBoot VS SpringMvc VS SpringBoot SpringBoot: SpringBoot 是一个快速开发的框架,能够快速的整合第三方框架,简化XML配置,全部采用注 ...

  9. [poj1061]青蛙的约会<扩展欧几里得>

    题目链接:http://poj.org/problem?id=1061 其实欧几里得我一直都知道,只是扩展欧几里得有点蒙,所以写了一道扩展欧几里得裸题. 欧几里得算法就是辗转相除法,求两个数的最大公约 ...

  10. Python Tkinter Grid布局管理器详解

    Grid(网格)布局管理器会将控件放置到一个二维的表格里.主控件被分割成一系列的行和列,表格中的每个单元(cell)都可以放置一个控件. 注意:不要试图在一个主窗口中混合使用pack和grid (1) ...