MNIST 手写数字识别 卷积神经网络 Pytorch框架

谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打

说明

下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释,方便查看。

代码实现

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms # Device configuration
#这里是个python的三元表达式,如果cuda存在的话,divice='cuda:0',否者就是'cpu'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Hyper parameters
num_epochs = 5 #全部训练集使用的次数
num_classes = 10 #全连接层输出的结果种类
batch_size = 100 #批处理的图片的个数
learning_rate = 0.001 #学习率,在梯度下降法里面的系数 # MNIST dataset
#下载训练数据集,位置放在本文件的父文件夹下的data文件夹里面,数据需要转换格式为Tensor
train_dataset = torchvision.datasets.FashionMNIST(root='../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
#下载测试集,位置放在放在本文件的父文件夹下的data文件夹里面,数据需要转换为Tensor格式
test_dataset = torchvision.datasets.FashionMNIST(root='../data/',
train=False,
transform=transforms.ToTensor()) # Data loader
#这里的shuffle(bool, optional):在每个epoch开始的时候,对数据进行重新打乱,就是重新分组
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
#对于测试集来说不需要进行从新分组(这里好像不是必须的,也可以试试每次测试分组,有意义吗?)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False) # Convolutional neural network (two convolutional layers)
#定义一个卷积类,这里需要继承nn.Module,它是专门为神经网络设计的模块化接口
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
#调用父类的初始化函数
super(ConvNet, self).__init__()
#一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
self.layer1 = nn.Sequential(
#二维卷积层,输入通道数1,输出通道数16(相当于有16个filter,也就是16个卷积核),卷积核大小为5*5,步长为1,零填充2圈
#经过计算,可以得到卷积输出的图像的大小和输入的图像大小是等大小的,但是深度不一样,为28*28*16(16为深度),因为这里的padding抵消了卷积的缩小
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
#BatchNorm2d是卷积网络中防止梯度消失或爆炸的函数,参数是卷积的输出通道数
nn.BatchNorm2d(16),
#激活函数
nn.ReLU(),
#二维最大池化,核的大小为2,步长为2
#这样输出的图片大小就是14*14*16(16为深度)
nn.MaxPool2d(kernel_size=2, stride=2))
#两层的卷积网络,具体含义和上面相同
self.layer2 = nn.Sequential(
#这里大小也没有变化,输出依然和输出的大小相同,深度为32,所以图像为14*14*32
#但是这里的卷积核的数量是32,和输出通道数相同。
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
#下面经过池化后输出就会变成7*7*32
nn.MaxPool2d(kernel_size=2, stride=2))
#对输入数据做线性变换,第一个参数是每个输入样本的大小:7*7*32;第二个参数是输出样本的大小,这里是10,正好代表10个数,相当于类别
#第三个参数为bias(偏差),默认为True。如果为False,那么这层将不会学习偏置。
self.fc = nn.Linear(7*7*32, num_classes) #定义了每次执行的 计算步骤。 在所有的子类中都需要重写forward函数。
#
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out
model = ConvNet(num_classes).to(device) # Loss and optimizer
#损失函数,
criterion = nn.CrossEntropyLoss()
#优化函数
#params (iterable)第一个参数:待优化参数的iterable或者是定义了参数组的dict
#lr (float, 可选) – 学习率(默认:1e-3)同样也称为学习率或步长因子,它控制了权重的更新比率(如 0.001)。
#较大的值(如 0.3)在学习率更新前会有更快的初始学习,而较小的值(如 1.0E-5)会令训练收敛到更好的性能。
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Train the model
#total_step是每一轮的测试次数,这里就是60000/100=600次
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device) # Forward pass
#model(images)等价于module.forward(images)
outputs = model(images)
#根据输出的结果和标签对比,计算loss
loss = criterion(outputs, labels) # Backward and optimize
#根据pytorch中的backward()函数的计算,当网络参量进行反馈时,梯度是被积累的而不是被替换掉;
#但是在每一个batch时毫无疑问并不需要将两个batch的梯度混合起来累积,因此这里就需要每个batch设置一遍zero_grad 了
#将梯度初始化为零
optimizer.zero_grad()
#这里是使用反向传播计算梯度值
loss.backward()
#在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次
optimizer.step() if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item())) # Test the model
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item() print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

手写数字识别 卷积神经网络 Pytorch框架实现的更多相关文章

  1. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

  2. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  3. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

  4. 识别手写数字增强版100% - pytorch从入门到入道(一)

    手写数字识别,神经网络领域的“hello world”例子,通过pytorch一步步构建,通过训练与调整,达到“100%”准确率 1.快速开始 1.1 定义神经网络类,继承torch.nn.Modul ...

  5. 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)

    通过: 手写数字识别  ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别  ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...

  6. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. TensorFlow 卷积神经网络手写数字识别数据集介绍

    欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! 手写数字识别 接下来将会以 MNIST 数据集为例,使用卷积层和池 ...

  9. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

随机推荐

  1. 【LOJ2316】「NOIP2017」逛公园

    [题目链接] [点击打开链接] [题目概括] 对给定\(K\),起点\(1\)到终点\(n\)中对长度为\([L,L+K]\)的路径计数. \(L\)为\(1\)到\(n\)的最短路长度. [思路要点 ...

  2. make文件基础用法

    参照:https://www.jianshu.com/p/0b2a7cb9a469 创建工作目录,包含一下文件 main.c person.c b.h c.h /*** c.h ***/ //this ...

  3. ZeroMQ+QT 字符串收发

    结合 Zeromq API函数 与 Qt 字符串QString QByteArray 实现字串收发: 发送端: zmq_msg_t msg; QString strT = “ABC汉字123”: QB ...

  4. 11.Python变量及其使用

    无论使用什么语言编程,其最终目的都是对数据进行处理.程序在编程过程中,为了处理数据更加方便,通常会将其存储在变量中. 形象地看,变量就像一个个小容器,用于“盛装”程序中的数据.除了变量,还有常量,它也 ...

  5. summernote(富文本编辑器)将附件与图片上传到自己的服务器(vue项目)

    1.上传图片至自己的服务器(这个官方都有例子,重点介绍附件上传)图片上传官方网址 // onChange callback $('#summernote').summernote({ callback ...

  6. golang string、int、int64 float 互相转换

    #string到int int,err := strconv.Atoi(string) #string到int64 int64, err := strconv.ParseInt(string, 10, ...

  7. 桥接模式下,主机能ping通虚拟机,虚拟机ping不通主机

    好像是防火墙阻止了什么东西而导致的无法ping通! 1.打开WIN7防火墙 2.选择高级设置 3.入站规则 4.找到配置文件类型为“公用”的“文件和打印共享(回显请求 – ICMPv4-In)”规则, ...

  8. nodejs 中的 cookie 及 session

    cookie-parser 插件:cookie解析,加密的操作 cookie-session 插件:session 的解析操作 http 是无状态的 cookie:在浏览器保存一些数据,每次向服务器发 ...

  9. Win7、win8、win10下实现精准截获Explorer拷贝行为

    介绍了windows下对Explorer的拷贝动作的精确截获,这个在企业数据安全dlp产品系列中减少审计的噪音很有效,方便运营人员做针对性的审计. 在企业数据安全中我通常需要监测用户的拷贝行为,特别像 ...

  10. Use an Excel RTD Server with DCOM

    费好大劲找到的文章,留存. Use an Excel RTD Server with DCOM 如何使用DCOM的Excel RTD服务器 Microsoft Office Excel 2007,Mi ...