手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架
谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打
说明
下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释,方便查看。
代码实现
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
#这里是个python的三元表达式,如果cuda存在的话,divice='cuda:0',否者就是'cpu'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Hyper parameters
num_epochs = 5 #全部训练集使用的次数
num_classes = 10 #全连接层输出的结果种类
batch_size = 100 #批处理的图片的个数
learning_rate = 0.001 #学习率,在梯度下降法里面的系数
# MNIST dataset
#下载训练数据集,位置放在本文件的父文件夹下的data文件夹里面,数据需要转换格式为Tensor
train_dataset = torchvision.datasets.FashionMNIST(root='../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
#下载测试集,位置放在放在本文件的父文件夹下的data文件夹里面,数据需要转换为Tensor格式
test_dataset = torchvision.datasets.FashionMNIST(root='../data/',
train=False,
transform=transforms.ToTensor())
# Data loader
#这里的shuffle(bool, optional):在每个epoch开始的时候,对数据进行重新打乱,就是重新分组
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
#对于测试集来说不需要进行从新分组(这里好像不是必须的,也可以试试每次测试分组,有意义吗?)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Convolutional neural network (two convolutional layers)
#定义一个卷积类,这里需要继承nn.Module,它是专门为神经网络设计的模块化接口
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
#调用父类的初始化函数
super(ConvNet, self).__init__()
#一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
self.layer1 = nn.Sequential(
#二维卷积层,输入通道数1,输出通道数16(相当于有16个filter,也就是16个卷积核),卷积核大小为5*5,步长为1,零填充2圈
#经过计算,可以得到卷积输出的图像的大小和输入的图像大小是等大小的,但是深度不一样,为28*28*16(16为深度),因为这里的padding抵消了卷积的缩小
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
#BatchNorm2d是卷积网络中防止梯度消失或爆炸的函数,参数是卷积的输出通道数
nn.BatchNorm2d(16),
#激活函数
nn.ReLU(),
#二维最大池化,核的大小为2,步长为2
#这样输出的图片大小就是14*14*16(16为深度)
nn.MaxPool2d(kernel_size=2, stride=2))
#两层的卷积网络,具体含义和上面相同
self.layer2 = nn.Sequential(
#这里大小也没有变化,输出依然和输出的大小相同,深度为32,所以图像为14*14*32
#但是这里的卷积核的数量是32,和输出通道数相同。
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
#下面经过池化后输出就会变成7*7*32
nn.MaxPool2d(kernel_size=2, stride=2))
#对输入数据做线性变换,第一个参数是每个输入样本的大小:7*7*32;第二个参数是输出样本的大小,这里是10,正好代表10个数,相当于类别
#第三个参数为bias(偏差),默认为True。如果为False,那么这层将不会学习偏置。
self.fc = nn.Linear(7*7*32, num_classes)
#定义了每次执行的 计算步骤。 在所有的子类中都需要重写forward函数。
#
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out
model = ConvNet(num_classes).to(device)
# Loss and optimizer
#损失函数,
criterion = nn.CrossEntropyLoss()
#优化函数
#params (iterable)第一个参数:待优化参数的iterable或者是定义了参数组的dict
#lr (float, 可选) – 学习率(默认:1e-3)同样也称为学习率或步长因子,它控制了权重的更新比率(如 0.001)。
#较大的值(如 0.3)在学习率更新前会有更快的初始学习,而较小的值(如 1.0E-5)会令训练收敛到更好的性能。
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
#total_step是每一轮的测试次数,这里就是60000/100=600次
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
# Forward pass
#model(images)等价于module.forward(images)
outputs = model(images)
#根据输出的结果和标签对比,计算loss
loss = criterion(outputs, labels)
# Backward and optimize
#根据pytorch中的backward()函数的计算,当网络参量进行反馈时,梯度是被积累的而不是被替换掉;
#但是在每一个batch时毫无疑问并不需要将两个batch的梯度混合起来累积,因此这里就需要每个batch设置一遍zero_grad 了
#将梯度初始化为零
optimizer.zero_grad()
#这里是使用反向传播计算梯度值
loss.backward()
#在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
手写数字识别 卷积神经网络 Pytorch框架实现的更多相关文章
- 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)
# 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
- 识别手写数字增强版100% - pytorch从入门到入道(一)
手写数字识别,神经网络领域的“hello world”例子,通过pytorch一步步构建,通过训练与调整,达到“100%”准确率 1.快速开始 1.1 定义神经网络类,继承torch.nn.Modul ...
- 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)
通过: 手写数字识别 ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别 ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...
- 卷积神经网络CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- TensorFlow 卷积神经网络手写数字识别数据集介绍
欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! 手写数字识别 接下来将会以 MNIST 数据集为例,使用卷积层和池 ...
- 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)
主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...
随机推荐
- 关系型数据库(四),引擎MyISAM和InnoDB
目录 1.MyISAM和InnoDB关于锁方面的区别是什么 2.MYSQL的两个常用存储引擎 3.MyISAM应用场景 4.InnoDB适合场景 四.引擎MyISAM和InnoDB 1.MyISAM和 ...
- 微信小程序开发整理
具体介绍包含以下内容: 1.文件结构 2.组件 4.API 4.工具 5.问题
- CentOS 7 各个版本的区别
CentOS 7 各个版本的区别 2017年07月04日 10:44:37 程诺 阅读数 52029 版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.n ...
- java web项目启动时浏览器路径不用输入项目名称方法
http://blog.csdn.net/qq542045215/article/details/44923851
- windows上批量杀指定进程
Taskkill 结束一个或多个任务或进程.可以根据进程 ID 或图像名来结束进程. 语法 taskkill [/s Computer] [/u Domain\User [/p Password]]] ...
- 数据加密之RSA
特别提示:本人博客部分有参考网络其他博客,但均是本人亲手编写过并验证通过.如发现博客有错误,请及时提出以免误导其他人,谢谢!欢迎转载,但记得标明文章出处:http://www.cnblogs.com/ ...
- 阿里镜像源配置yum
通过more /etc/*release* 查看系统版本 (需要下载对应的系统版本) mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/Cen ...
- 转载 Golang []byte与string转换的一个误区
Golang []byte与string转换的一个误区 https://www.oyohyee.com/post/Note/golang_byte_to_string/ 2019-08-10 23:4 ...
- Serializable 和 Parcelable 的区别?
1.在使用内存的时候,Parcelable 类比 Serializable 性能高,所以推荐使用 Parcelable 类.2.Serializable 在序列化的时候会产生大量的临时变量,从而引起频 ...
- netfilter/iptables 防火墙
目录 文章目录 目录 iptables 与 netfilter 工作机制 规则(Rules) 链(chain) 表(tables) 网络数据包通过 iptables 的过程 总结链.表和规则的关系 i ...