任务目标

对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率。(最终本文达到了\(99.36\%\))

使用的库的版本:

  1. python:3.8.12
  2. pytorch:1.5.1

代码地址GitHub:https://github.com/xiaohuiduan/deeplearning-study/tree/main/手写数字识别

数据集介绍

MNIST数字数据集来自MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

在torchvision中自带了关于MNIST的数据集。如果直接使用自带的数据集,能方便不少。关于具体使用,可参考:PyTorch初探MNIST数据集 - 知乎 (zhihu.com)

在Lecun的提供的MNIST数据集,有如下4个文件(images文件和labels文件):

training set包含了60000张手写数字图片,test set包含了10000张图片。在images文件和labels文件中,数据是使用二进制进行保存的。

图像文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):

  • 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;

  • 第5-8个byte存的是number of images,即图像数量60000;

  • 第9-12个byte存的是每张图片行数/高度,即28;

  • 第13-16个byte存的是每张图片的列数/宽度,即28。

  • 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

标签文件的二进制储存格式如下(参考python处理MNIST数据集 - 简书 (jianshu.com)):

  • 第1-4个byte存的是文件的magic number,对应的十进制大小是2049;

  • 第5-8个byte存的是number of items,即label数量60000;

  • 从第9个byte开始,每个byte存一个图片的label信息,即数字0-9中的一个。

二进制文件的Python处理代码:

import numpy as np
def read_image(file_path):
"""读取MNIST图片 Args:
file_path (str): 图片文件位置 Returns:
list: 图片列表
"""
with open(file_path,'rb') as f:
file = f.read()
img_num = int.from_bytes(file[4:8],byteorder='big') #图片数量
img_h = int.from_bytes(file[8:12],byteorder='big') #图片h
img_w = int.from_bytes(file[12:16],byteorder='big') #图片w
img_data = []
file = file[16:]
data_len = img_h*img_w for i in range(img_num):
data = [item/255 for item in file[i*data_len:(i+1)*data_len]]
img_data.append(np.array(data).reshape(img_h,img_w)) return img_data def read_label(file_path):
with open(file_path,'rb') as f:
file = f.read()
label_num = int.from_bytes(file[4:8],byteorder='big') #label的数量
file = file[8:]
label_data = []
for i in range(label_num):
label_data.append(file[i])
return label_data train_img = read_image("mnist/train/train-images.idx3-ubyte")
train_label = read_label("mnist/train/train-labels.idx1-ubyte") # test_img = read_image("mnist/test/t10k-images.idx3-ubyte")
# test_label = read_label("mnist/test/t10k-labels.idx1-ubyte")

数据集部分数据如下所示:

数据集划分

在深度学习中,需要将trainset划分成训练集验证集。最终使用测试集去验证模型的结果。

训练集:用来训练模型参数。

验证集:验证模型的状况和收敛情况。

测试集:验证模型结果。

形象上来说训练集就像是学生的课本,学生 根据课本里的内容来掌握知识,验证集就像是作业,通过作业可以知道 不同学生学习情况、进步的速度快慢,而最终的测试集就像是考试,考的题是平常都没有见过,考察学生举一反三的能力。

来源:训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)

因此,需要将上文中的train_img,train_label进行划分,划分为训练集验证集。这里使用sklearn中的train_test_split进行划分,训练集和测试集的比例为\(8:2\)。

from sklearn.model_selection import train_test_split
train_img,valid_img,train_label,valid_label = train_test_split(train_img,train_label,test_size=0.2,shuffle=True)

网络结构

根据网络的权重,Netron生成的网络结构图如下,图中详细的介绍了每一层的结构参数。

网络结构的简洁图如下所示,网络一共由3层卷积层(每层卷积分别由Conv2d,BatchNorm2d,MaxPool2d和Dropout构成)和2个全连接层构成。

Pytorch代码如下:

class MyNet(nn.Module):
def __init__(self):
super(MyNet,self).__init__()
self.conv_1 = nn.Sequential(
nn.Conv2d(1,32,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
nn.MaxPool2d(2,2),
nn.Dropout(0.25)
)
self.conv_2 = nn.Sequential(
nn.Conv2d(32,64,kernel_size=3,padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
) self.conv_3 = nn.Sequential(
nn.Conv2d(64,128,kernel_size=3),
nn.ReLU(),
nn.BatchNorm2d(128),
nn.MaxPool2d(2,2),
nn.Dropout(0.25),
) self.fc = nn.Sequential(
nn.Linear(512,128),
nn.Linear(128,10)
) def forward(self,x): #x (3,28,28)
x = self.conv_1(x) #x (32,14,14)
x = self.conv_2(x) #x (64,7,7)
x = self.conv_3(x) #x (128,4,4)
x = x.view(x.size(0),-1) x = self.fc(x)
return F.log_softmax(x,dim=1)
myNet = MyNet().to(device)

训练集以及验证集结果

大概经过300个epoch训练,验证集便能够达到\(99.9\%\)以上的正确率。

训练集的Loss曲线:

测试集结果

测试集使用训练400个epoch之后的模型进行预测。其最终预测的正确率为:\(99.36 \%\)。实际上,大概300个epoch就能够在测试集达到\(99\%\)以上的正确率。

参考

  1. MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
  2. MNIST — Torchvision 0.12 documentation (pytorch.org)
  3. python处理MNIST数据集 - 简书 (jianshu.com)
  4. 训练集(train)验证集(validation)测试集(test)与交叉验证法 - 知乎 (zhihu.com)
  5. sklearn.model_selection.train_test_split — scikit-learn 1.0.2 documentation
  6. Netron

深度学习(一)之MNIST数据集分类的更多相关文章

  1. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  2. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  3. 深度学习之 cnn 进行 CIFAR10 分类

    深度学习之 cnn 进行 CIFAR10 分类 import torchvision as tv import torchvision.transforms as transforms from to ...

  4. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  5. keras框架下的深度学习(二)二分类和多分类问题

    本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...

  6. 深度学习笔记(一):logistic分类【转】

    本文转载自:https://blog.csdn.net/u014595019/article/details/52554582 这个系列主要记录我在学习各个深度学习算法时候的笔记,因为之前已经学过大概 ...

  7. Tensorflow学习教程------普通神经网络对mnist数据集分类

    首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...

  8. 自己动手实现深度学习框架-8 RNN文本分类和文本生成模型

    代码仓库: https://github.com/brandonlyg/cute-dl 目标         上阶段cute-dl已经可以构建基础的RNN模型.但对文本相模型的支持不够友好, 这个阶段 ...

  9. Python深度学习案例1--电影评论分类(二分类问题)

    我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...

随机推荐

  1. 编译安装haproxy

    一.安装lua环境 1.1 安装依赖包 [root@centos7 ~]# yum install gcc readline-devel 1.2 下线lua源码包并解压 [root@centos7 ~ ...

  2. Linux重定向输出到以当前时间命名的文件 / date命令格式化输出

    1. 利用date命令重定向到以当前时间命名的文件 例如: ls -l > mylog_$(date +"%Y-%m-%d_%H-%M-%S").log 或: ls -l & ...

  3. Oracle用户创建、删除和授权等方法总结

    一.查看用户及权限 1.查询所有用户: 1.1.查看所有用户基本信息 select * from all_users; 1.2.查看所有用户相信信息 select * from dba_users; ...

  4. 微服务从代码到k8s部署应有尽有系列(二、网关)

    我们用一个系列来讲解从需求到上线.从代码到k8s部署.从日志到监控等各个方面的微服务完整实践. 整个项目使用了go-zero开发的微服务,基本包含了go-zero以及相关go-zero作者开发的一些中 ...

  5. Solution -「Code+#2」「洛谷 P4033」白金元首与独舞

    \(\mathcal{Description}\)   link.   给定一个 \(n\times m\) 的网格图,一些格子指定了走出该格的方向(上下左右),而有 \(k\) 格可以任意指定走出方 ...

  6. Solution -「CF 1025G」Company Acquisitions

    \(\mathcal{Description}\)   Link.   \(n\) 个公司,每个公司可能独立或者附属于另一个公司.初始时,每个公司附属于 \(a_i\)(\(a_i=-1\) 表示该公 ...

  7. Solution -「CF 1342E」Placing Rooks

    \(\mathcal{Description}\)   Link.   在一个 \(n\times n\) 的国际象棋棋盘上摆 \(n\) 个车,求满足: 所有格子都可以被攻击到. 恰好存在 \(k\ ...

  8. python2发微信脚本

    #!/usr/bin/env python # -*- coding: utf-8 -*- import urllib,urllib2,json import sys reload(sys) sys. ...

  9. Linux mysql8.0.11安装

    准备:检查是否已安装过mysql,若有便删除(linux系统自带的) rpm -qa | grep mariadb rpm -e nodeps mariadb-libs-5.5.56-2.el7.x8 ...

  10. 轩辕展览-VR虚拟展厅设计的好处和优势是什么?

    yu情仍在继续,实体展厅很糟糕,在过去两年之中,越来越多的实体展厅因闲置而关闭,线上VR虚拟展厅设计逐渐走出圈子,凭借云展示的优势和国家政策的支持,登上展示和销售的旗帜. 产品线上展厅的优势是什么1. ...