1. 准备数据集

1.1 MNIST数据集获取:

  • torchvision.datasets接口直接下载,该接口可以直接构建数据集,推荐

  • 其他途径下载后,编写程序进行读取,然后由Datasets构建自己的数据集

​ ​ 本文使用第一种方法获取数据集,并使用Dataloader进行按批装载。如果使用程序下载失败,请将其他途径下载的MNIST数据集 [文件][解压文件] 放置在 <data/MNIST/raw/> 位置下,本文的程序及文件结构图如下:

​ ​ 其中,model文件夹用来存储每个epoch训练的模型参数,根文件夹下包含model.py用于训练模型,test.py为测试集测试,show.py为展示部分

1.2 程序部分

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time # 1. 准备数据集
## 1.1 使用torchvision自动下载MNIST数据集
train_data = datasets.MNIST(root='data\\',
train=True,
transform=transforms.ToTensor(),
download=True) ## 1.2 构建数据集装载器
train_loader = DataLoader(dataset=train_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4) if __name__ == "__main__":
print("===============数据统计===============")
print("训练集样本:",train_data.__len__(), train_data.data.shape)

​ ​ 【代码解析】

  • root为存放MNIST的路径,trian=True代表下载的为训练集和训练集标签,False则代表测试集和标签

  • transforms.ToTensor()表示将shape为(H, W, C)的 numpy 数组或 img 转为shape为(C, H, W)的tensor,并将数值归一化为[0,1]

  • download为True则代表自动下载,若该文件夹下已经下载,则直接跳过下载步骤

  • shuffle=True,表示对分好的batch进行洗牌操作,drop_last=True表示对最后不足batch大小的剩余样本舍去,False表示保留

  • num_works表示每次读取的进程数,和核心数有关

​ ​ Dataset和Dataloader详细说明,请移步:[Pytorch Dataset和Dataloader 学习笔记(二)]

2. 设计网络结构

2.1 网络设计

​ ​ 网络结构如上图所示,输入图像—>卷积1—>池化1—>卷积2—>池化2—>全连接1—>全连接2—>softmax,每次卷积通道数都增加一倍,最后送入全连接层实现分类

2.2 程序部分

# 2. Design model using class
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.max_pooling1 = nn.MaxPool2d(2)
self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.max_pooling2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(1568, 256)
self.fc2 = nn.Linear(256, 10) def forward(self, x):
x = self.max_pooling1(F.relu(self.conv_layer1(x)))
x = self.max_pooling2(F.relu(self.conv_layer2(x)))
x = x.view(-1, 32*7*7)
x = F.relu(self.fc1(x))
y_hat = self.fc2(x) # CrossEntropyLoss会自动激活最后一层的输出以及softmax处理
return y_hat net = Net() # 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

​ ​ 【代码解析】

  • fc1的1568维度是因为最后一次池化后的shape为32*7*7=1568

  • 在最后一层,并没有进行relu激活以及接入softmax,是因为,在CrossEntropyLoss中会自动激活最后一层的输出以及softmax处理

​ ​ CrossEntropyLoss图参考:《PyTorch深度学习实践》完结合集

​ ​ 详细网络结构搭建说明,请移步:Pytorch线性规划模型 学习笔记(一)

3. 迭代训练

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5) # 4. Training
if __name__ == "__main__":
print("Training...")
for epoch in range(20):
strat = time.time()
total_correct = 0
for x, y in train_loader:
y_hat = net(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) # 统计当前epoch下的正确个数 loss = criterion(y_hat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (float(total_correct) / train_data.__len__())*100
save_path = "model/net" + str(epoch+1) + ".pth"
torch.save(obj=net.state_dict(), f=save_path)
print("epoch:", str(epoch + 1) + "/20",
" \n time:", "%.1f" % (time.time() - strat) + "s"
" train_loss:", loss.item(),
" acc:%.3f%%" % acc,) print("we are done!")

​ ​ 【代码解析】

  • total_correct变量用于统计每个epoch下正确预测值的个数,每进行epoch进行一次清零
  • torch.argmax(y_hat, dim=1)用于选取y_hat下每一行的最大值(每个样本的最高得分),并返回与y相同维度的tensor
  • torch.eq(y_pre, y)用于比较两个矩阵元素是否相同,相同则返回True,不同则返回False,用于判断预测值与真实值是否相同
  • torch.save保存了每个epoch的网络权重参数

4. 测试集预测部分

# 测试模型,测试集为test_data

import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net test_data = datasets.MNIST(root='data\\',
train=False,
transform=transforms.ToTensor(),
download=True)
test_loader = DataLoader(dataset=test_data,
batch_size=100,
shuffle=True,
drop_last=False,
num_workers=4) if __name__ == "__main__":
print("---------------预测分析---------------")
print("测试集样本:", test_data.__len__(), test_data.data.shape)
model = Net()
model.load_state_dict(torch.load("model/net20.pth"))
model.eval() total_correct = 0
for x, y in test_loader:
y_hat = model(x)
y_pre = torch.argmax(y_hat, dim=1)
total_correct += sum(torch.eq(y_pre, y)) acc = (float(total_correct) / test_data.__len__())*100
print("total_test_samples:", test_data.__len__(),
" test_acc:", "%.3f%%" % acc)

​ ​ 经过20个epoch的训练,在测试集上达到了98.590%的准确率,部分batch真实值与预测值展示如下:


5. 全部代码

链接:链接:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw

提取码:82l4

转载请说明出处

Pytorch CNN网络MNIST数字识别 [超详细记录] 学习笔记(三)的更多相关文章

  1. pytorch CNN 手写数字识别

    一个被放弃的入门级的例子终于被我实现了,虽然还不太完美,但还是想记录下 1.预处理 相比较从库里下载数据集(关键是经常失败,格式也看不懂),更喜欢直接拿图片,从网上找了半天,最后从CSDN上下载了一个 ...

  2. 用pytorch做手写数字识别,识别l率达97.8%

    pytorch做手写数字识别 效果如下: 工程目录如下 第一步  数据获取 下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下 from torchvision.dataset ...

  3. muduo网络库学习笔记(三)TimerQueue定时器队列

    目录 muduo网络库学习笔记(三)TimerQueue定时器队列 Linux中的时间函数 timerfd简单使用介绍 timerfd示例 muduo中对timerfd的封装 TimerQueue的结 ...

  4. Python 工匠:使用数字与字符串的技巧学习笔记

    #Python 工匠:使用数字与字符串的技巧学习笔记#https://github.com/piglei/one-python-craftsman/blob/master/zh_CN/3-tips-o ...

  5. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

  6. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  7. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  8. TensorFlow学习笔记(三)MNIST数字识别问题

    一.MNSIT数据处理 MNSIT是一个非常有名的手写体数字识别数据集.包含60000张训练图片,10000张测试图片.每张图片是28X28的数字. TonserFlow提供了一个类来处理 MNSIT ...

  9. MNIST数字识别问题

    摘自<Tensorflow:实战Google深度学习框架> import tensorflow as tf from tensorflow.examples.tutorials.mnist ...

随机推荐

  1. 一道codeforces题目引发的差分学习

    Codeforces Round #688 (Div. 2) 题目:B. Suffix Operations 题意:给定一个长为n的数组a,你可以进行两种操作:1).后缀+1;     2)后缀-1: ...

  2. 02 CTF WEB 知识梳理

    1. 工具集 基础工具 Burpsuit, Python, FireFox(Hackbar, FoxyProxy, User-Agent Swither .etc) Burpsuit 代理工具,攻击w ...

  3. 基于虹软人脸识别,实现RTMP直播推流追踪视频中所有人脸信息(C#)

    前言 大家应该都知道几个很常见的例子,比如在张学友的演唱会,在安检通道检票时,通过人像识别系统成功识别捉了好多在逃人员,被称为逃犯克星:人行横道不遵守交通规则闯红灯的路人被人脸识别系统抓拍放在大屏上以 ...

  4. 80行代码教你写一个Webpack插件并发布到npm

    1. 前言 最近在学习 Webpack 相关的原理,以前只知道 Webpack 的配置方法,但并不知道其内部流程,经过一轮的学习,感觉获益良多,为了巩固学习的内容,我决定尝试自己动手写一个插件. 这个 ...

  5. YAML/YML文件一直提示格式错误解决方法

    第一次接触yml文件,各种格式报错,但是看了几次也没看出来.其实有一个好方法,那就是直接通过yml在线格式检查 可以将yml具体内容复制到以下网址进行查询.具体报错位置会更加详细 https://ww ...

  6. git cherry-pick(不同分支的提交合并)

    git cherry-pick可以选择某一个分支中的一个或几个commit(s)来进行操作.例如,假设我们有个稳定版本的分支,叫v2.0,另外还有个开发版本的分支v3.0,我们不能直接把两个分支合并, ...

  7. Centos6.8安装mysql 步骤

    第1步.查看CentOS下是否已安装mysql 输入命令 :yum list installed | grep mysql 第2步.删除已安装mysql 输入命令:yum -y remove mysq ...

  8. IT菜鸟之路由器基础配置(静态、动态、默认路由)

    路由器:连接不同网段的设备 企业级路由和家用级路由的区别: 待机数量不同(待机量) 待机量:同时接通的终端设备的数量 待机量的值越高,路由的性能越好 别墅级路由,表示信号好,和性能无关 交换机:背板带 ...

  9. docker命令补全

    安装docker自带包: source /usr/share/bash-completion/completions/docker 缺少下面的包,TAB会报错 yum install -y bash- ...

  10. GLSL着色器,来玩

    对实现动画的前端同学们来说,canvas可以说是最自由,最能全面控制的一个动画实现载体.不但能通过javascript控制点.线.面的绘制,使用图片资源填充:还能改变输入参数作出交互动画,完全控制动画 ...