写在前面

由于MLP的实现框架已经非常完善,网上搜到的代码大都大同小异,而且MLP的实现是deeplearning学习过程中较为基础的一个实验。因此完全可以找一份源码以参考,重点在于照着源码手敲一遍,以熟悉pytorch的基本操作。

实验要求

熟悉pytorch的基本操作:用pytorch实现MLP,并在MNIST数据集上进行训练

环境配置

实验环境如下:

  • Win10
  • python3.8
  • Anaconda3
  • Cuda10.2 + cudnn v7
  • GPU : NVIDIA GeForce MX250

配置环境的过程中遇到了一些问题,解决方案如下:

  1. anaconda下载过慢

    使用清华镜像源,直接百度搜索即可

  2. pytorch安装失败

    这里我首先使用的是pip的安装方法,失败多次后尝试了使用anaconda,然后配置了清华镜像源,最后成功。参考的教程如下:

    win10快速安装pytorch,清华镜像源

    当然也可以直接去pytorch官网下载所需版本的whl文件,然后手动pip安装。由于这种方式我已经学会了,为了学习anaconda,所以没有采用这种方式。具体方式可以百度如何使用whl。顺便贴下pytorch的whl的下载页面

注意:pytorch的版本是要严格对应是否使用GPU、python版本、cuda版本的,如需手动下载pytorch的安装包,需搞懂其whl文件的命名格式

另外还学习了anaconda的一些基本操作与原理,参考如下:

Anaconda完全入门指南

实验过程

最终代码见github:hit-deeplearning-1

首先设置一些全局变量,加载数据。batch_size决定了每次向网络中输入的样本数,epoch决定了整个数据集的迭代次数,具体作用与大小如何调整可参考附录中的博客。

将数据读入,如果数据不存在于本地,则可以自动从网上下载,并保存在本地的data文件夹下。

  1. #一次取出的训练样本数
  2. batch_size = 16
  3. # epoch 的数目
  4. n_epochs = 10
  5. #读取数据
  6. train_data = datasets.MNIST(root="./data", train=True, download=True,transform=transforms.ToTensor())
  7. test_data = datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())
  8. #创建数据加载器
  9. train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, num_workers = 0)
  10. test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, num_workers = 0)

接下来是创建MLP模型,关于如何创建一个模型,可以参考附录中的博客,总之创建模型模板,训练模板都是固定的。

其中LinearviewCrossEntropyLossSGD的用法需重点关注。查看官方文档或博客解决。

这两条语句将数据放到了GPU上,同理测试的时候也要这样做。

  1. data = data.cuda()
  2. target = target.cuda()
  1. class MLP(nn.Module):
  2. def __init__(self):
  3. #继承自父类
  4. super(MLP, self).__init__()
  5. #创建一个三层的网络
  6. #输入的28*28为图片大小,输出的10为数字的类别数
  7. hidden_first = 512
  8. hidden_second = 512
  9. self.first = nn.Linear(in_features=28*28, out_features=hidden_first)
  10. self.second = nn.Linear(in_features=hidden_first, out_features=hidden_second)
  11. self.third = nn.Linear(in_features=hidden_second, out_features=10)
  12. def forward(self, data):
  13. #先将图片数据转化为1*784的张量
  14. data = data.view(-1, 28*28)
  15. data = F.relu(self.first(data))
  16. data = F.relu((self.second(data)))
  17. data = F.log_softmax(self.third(data), dim = 1)
  18. return data
  19. def train():
  20. # 定义损失函数和优化器
  21. lossfunc = torch.nn.CrossEntropyLoss().cuda()
  22. #lossfunc = torch.nn.CrossEntropyLoss()
  23. optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
  24. # 开始训练
  25. for epoch in range(n_epochs):
  26. train_loss = 0.0
  27. for data, target in train_loader:
  28. optimizer.zero_grad()
  29. #将数据放至GPU并计算输出
  30. data = data.cuda()
  31. target = target.cuda()
  32. output = model(data)
  33. #计算误差
  34. loss = lossfunc(output, target)
  35. #反向传播
  36. loss.backward()
  37. #将参数更新至网络中
  38. optimizer.step()
  39. #计算误差
  40. train_loss += loss.item() * data.size(0)
  41. train_loss = train_loss / len(train_loader.dataset)
  42. print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))
  43. # 每遍历一遍数据集,测试一下准确率
  44. test()
  45. #最后将模型保存
  46. path = "model.pt"
  47. torch.save(model, path)

test程序不再贴出,直接调用了一个很常用的test程序。

最后是主程序,在这里将模型放到GPU上。

  1. model = MLP()
  2. #将模型放到GPU上
  3. model = model.cuda()
  4. train()

实验结果

实验结果如下,可以看到,当对数据迭代训练十次时,准确率已经可以达到97%

分别运行了两次,第一次没有使用cuda加速,第二次使用了cuda加速,任务管理器分别显示如下:

可以看到,未使用cuda加速时,cpu占用率达到了100%,而GPU的使用率为0;而使用cuda加速时,cpu占用率只有49%,而GPU使用率为1%。这里GPU使用率较低的原因很多,比如我程序中batch_size设置的较小,另外只将数据和模型放到了GPU上,cpu上仍有部分代码与数据。经简单测试,使用cuda的训练时间在2:30左右,不使用cuda的训练时间在3:40左右。

参考博客

使用Pytorch构建MLP模型实现MNIST手写数字识别

如何创建自定义模型

pytorch教程之nn.Module类详解——使用Module类来自定义网络层

epoch和batch是什么

深度学习 | 三个概念:Epoch, Batch, Iteration

如何用GPU加速

从头学pytorch(十三):使用GPU做计算

PyTorch如何使用GPU加速(CPU与GPU数据的相互转换)

保存模型

PyTorch模型保存与加载

pytorch实现MLP并在MNIST数据集上验证的更多相关文章

  1. MNIST数据集上卷积神经网络的简单实现(使用PyTorch)

    设计的CNN模型包括一个输入层,输入的是MNIST数据集中28*28*1的灰度图 两个卷积层, 第一层卷积层使用6个3*3的kernel进行filter,步长为1,填充1.这样得到的尺寸是(28+1* ...

  2. caffe在windows编译project及执行mnist数据集測试

    caffe在windows上的配置和编译能够參考例如以下的博客: http://blog.csdn.net/joshua_1988/article/details/45036993 http://bl ...

  3. TersorflowTutorial_MNIST数据集上简单CNN实现

    MNIST数据集上简单CNN实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Tensorflow机器学习实战指南 源代码请点击下方链接欢迎加星 Tesorflow实现基于MNI ...

  4. 【转载】用Scikit-Learn构建K-近邻算法,分类MNIST数据集

    原帖地址:https://www.jiqizhixin.com/articles/2018-04-03-5 K 近邻算法,简称 K-NN.在如今深度学习盛行的时代,这个经典的机器学习算法经常被轻视.本 ...

  5. pytorch 加载mnist数据集报错not gzip file

    利用pytorch加载mnist数据集的代码如下 import torchvision import torchvision.transforms as transforms from torch.u ...

  6. PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类

    迁移学习的两个主要场景 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练 下面修改 ...

  7. 基于Keras 的VGG16神经网络模型的Mnist数据集识别并使用GPU加速

    这段话放在前面:之前一种用的Pytorch,用着还挺爽,感觉挺方便的,但是在最近文献的时候,很多实验都是基于Google 的Keras的,所以抽空学了下Keras,学了之后才发现Keras相比Pyto ...

  8. SGD与Adam识别MNIST数据集

    几种常见的优化函数比较:https://blog.csdn.net/w113691/article/details/82631097 ''' 基于Adam识别MNIST数据集 ''' import t ...

  9. MXNet学习-第一个例子:训练MNIST数据集

    一个门外汉写的MXNET跑MNIST的例子,三层全连接层最后验证率是97%左右,毕竟是第一个例子,主要就是用来理解MXNet怎么使用. #导入需要的模块 import numpy as np #num ...

随机推荐

  1. mac主机无法访问虚拟机中的Ubuntu运行的web服务

    第一点: 检查主机和虚拟机之间是否连通: 在mac主机中ping 虚拟机ip 虚拟机ip可以在虚拟机命令行中输入 ifconfig查看 第二点: 如果不能ping通,改变虚拟机的网络连接方式为:桥接模 ...

  2. 关于 IDEA 启动 springboot 项目异常 - Disconnected from the target VM, address: '127.0.0.1:59770', transport: 'socket'

    关于 IDEA 启动 springboot 项目异常 - Disconnected from the target VM, address: '127.0.0.1:59770', transport: ...

  3. 广告行业中那些趣事系列8:详解BERT中分类器源码

    最新最全的文章请关注我的微信公众号:数据拾光者. 摘要:BERT是近几年NLP领域中具有里程碑意义的存在.因为效果好和应用范围广所以被广泛应用于科学研究和工程项目中.广告系列中前几篇文章有从理论的方面 ...

  4. WEB安全——XML注入

    浅析XML注入 认识XML DTD XML注入 XPath注入 XSL和XSLT注入 前言前段时间学习了.net,通过更改XML让连接数据库变得更方便,简单易懂,上手无压力,便对XML注入这块挺感兴趣 ...

  5. Reface.NPI 方法名称解析规则详解

    在上次的文章中简单介绍了 Reface.NPI 中的功能. 本期,将对这方法名称解析规则进行详细的解释和说明, 以便开发者可以完整的使用 Reface.NPI 中的各种功能. 基本规则 方法名称以 I ...

  6. 认识STM32芯片

    STM32中的ST指的是意法半导体,M是Microelectronics的缩写,32表示32位,即意法半导体公司开发的32位微控制器 ST官网:https://www.st.com/content/s ...

  7. touch多点触摸事件

    touch--单点 targetTouches. changeTouches 多点: targetTouches--当前物体上的手指数 *不同物体上的手指不会互相干扰 不需要做多点触摸的时候---平均 ...

  8. 整数逆序输出 Python

    输入形式:123  输出形式:321 输入形式:120 输出形式:21  (整数不能以0打头) 输入形式:-123 输出形式:-321 代码: a=int(input()) b=0 if a<0 ...

  9. yii2框架学习笔记

    1.去掉yii2模版默认的头部和脚部的两种方法: (1) 第一种 $this->layout = false; $this->render('index'); (2) 第二种(partia ...

  10. ftl中几个特殊的用法

    @ 注意${}为变量的渲染显示,即先计算后打印出来,而<>里面为定义等操作符的定义 ,而首尾2个<>中间部分一般为计算打印部分 @数据模型中如果不是以map数据来封装的,而是直 ...