Pytorch搭建卷积神经网络用于MNIST分类
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim
from torch.nn import functional as F EPOCH = 1000
BATCH_SIZE = 128
LR = 0.001
DOWNLOAD_MNIST = False train_data = datasets.MNIST(
root='./mnist',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),#0-255 -> 0-1
download=DOWNLOAD_MNIST
)
#plot one example
print(train_data.train_data.size())
print(train_data.train_labels.size())
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show() train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE,\
shuffle=True, num_workers=2) test_data = datasets.MNIST(
root='./mnist',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),#0-255 -> 0-1
download=DOWNLOAD_MNIST
)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE,\
shuffle=True, num_workers=2) x, label = iter(test_loader).next() #这个iter能把一个batch_size提取出来
print("x:",x.shape, 'label:',label.shape) class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
nn.Conv2d(
in_channels=1, # input height
out_channels=10, # n_filters
kernel_size=5, # filter size
stride=1, # filter movement/step
padding=2, # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
), # output shape (16, 28, 28)
nn.ReLU(), # activation
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
)
self.conv2 = nn.Sequential( # input shape (16, 14, 14)
nn.Conv2d(10, 20, 5, 1, 2), # output shape (32, 14, 14)
nn.ReLU(), # activation
nn.MaxPool2d(2), # output shape (32, 7, 7)
)
self.out1 = nn.Linear(20 * 7 * 7, 512) # fully connected layer, output 10 classes
self.out2 = nn.Linear(512, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
output = self.out1(x)
output = F.relu(output)
output = self.out2(output)
return output # return x for visualization DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cnn = CNN().to(DEVICE)
print(cnn) # net architecture
optimizer = optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
criteon = nn.CrossEntropyLoss().to(DEVICE) # the target label is not one-hotted def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if(batch_idx+1)%30 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item())) def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += torch.eq(pred, target).float().sum().item()
test_loss += criterion(output, target) # test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
# pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
# correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset))) # training and testing
for epoch in range(EPOCH):
train(cnn, DEVICE, train_loader, optimizer, epoch)
test(cnn, DEVICE, test_loader)
最后能得到99%的准确率
Pytorch搭建卷积神经网络用于MNIST分类的更多相关文章
- Tensorflow框架初尝试————搭建卷积神经网络做MNIST问题
Tensorflow是一个非常好用的deep learning框架 学完了cs231n,大概就可以写一个CNN做一下MNIST了 tensorflow具体原理可以参见它的官方文档 然后CNN的原理可以 ...
- Pytorch搭建简单神经网络 Task2
1>建立数据集(并绘制图像) # -*- coding: utf-8 -*- #demo.py import torch import torch.nn.functional as F # 主要 ...
- TensorFlow——CNN卷积神经网络处理Mnist数据集
CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...
- 3层-CNN卷积神经网络预测MNIST数字
3层-CNN卷积神经网络预测MNIST数字 本文创建一个简单的三层卷积网络来预测 MNIST 数字.这个深层网络由两个带有 ReLU 和 maxpool 的卷积层以及两个全连接层组成. MNIST 由 ...
- 使用MXNet远程编写卷积神经网络用于多标签分类
最近试试深度学习能做点什么事情.MXNet是一个与Tensorflow类似的开源深度学习框架,在GPU显存利用率上效率高,比起Tensorflow显著节约显存,并且天生支持分布式深度学习,单机多卡.多 ...
- TensorFlow系列专题(十四): 手把手带你搭建卷积神经网络实现冰山图像分类
目录: 冰山图片识别背景 数据介绍 数据预处理 模型搭建 结果分析 总结 一.冰山图片识别背景 这里我们要解决的任务是来自于Kaggle上的一道赛题(https://www.kaggle.com/c/ ...
- 动手学习Pytorch(6)--卷积神经网络基础
卷积神经网络基础 本节我们介绍卷积神经网络的基础概念,主要是卷积层和池化层,并解释填充.步幅.输入通道和输出通道的含义. 二维卷积层 本节介绍的是最常见的二维卷积层,常用于处理图像数据. 二维 ...
- 写给程序员的机器学习入门 (八) - 卷积神经网络 (CNN) - 图片分类和验证码识别
这一篇将会介绍卷积神经网络 (CNN),CNN 模型非常适合用来进行图片相关的学习,例如图片分类和验证码识别,也可以配合其他模型实现 OCR. 使用 Python 处理图片 在具体介绍 CNN 之前, ...
- 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)
1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...
随机推荐
- SCC统计
Kosoraju SCC总数及记录SCC所需要的最少边情况 #include<cstdio> ; ; ][N], nxt[][N], v[][N], ed, q[N], t, vis[N] ...
- Linux查看磁盘空间大小
1. Ubuntu 查看磁盘空间大小命令 df -h Df命令是linux系统以磁盘分区为单位查看文件系统,可以加上参数查看磁盘剩余空间信息, 命令格式: df -hl 显示格式为: 文件系统 容 ...
- Python之网路编程利用threading模块开线程
一多线程的概念介绍 threading模块介绍 threading模块和multiprocessing模块在使用层面,有很大的相似性. 二.开启多线程的两种方式 1 1.创建线程的开销比创建进程的开销 ...
- vs2017 mvc 启动时经常出现调用的目标发生异常
1.vs 2017 调试web 程序时老是出现调用的目标发生异常 本人眼拙,基本上看了网站说的一些方法,设置环境变量是无效的,只有一个办法,卸载重装. 1.0 卸载过程 打开计算机-卸载或更改软件- ...
- 【转】encodeURI和decodeURI方法
为什么要两次调用encodeURI来解决乱码问题 https://blog.csdn.net/howlaa/article/details/12834595 请注意 encodeURIComponen ...
- day_08 字符编码乱码处理
Python3默认编码是unicode:而Python2是ASCII码.Windows环境默认是gbk编码. 常见编码错误原因: 1. Python解释器的默认编码 2. Python源文件文件编码 ...
- MySQL——复制(Replication)
1.复制概述 1.1.复制解决的问题数据复制技术有以下一些特点:(1) 数据分布(2) 负载平衡(load balancing)(3) 备份(4) 高可用性(high avai ...
- 对DOMContentLoaded的研究 -----------------------引用
1. 什么是 DOMContentLoaded.打开 Chrome DevTools,切到 Network 面板,重新加载网页,得到如下截图: 标记 1 指向的蓝线以及标记 2 指向的蓝色字 “ ...
- js对对象增加删除属性
1.首选创建一个对象 var a={}; 2.然后对这个对象赋值 a.name='zhouy';console.log(a);var age="age";a[age]=26;con ...
- 部署zabbix 4.0 + grafana
不完整,仅供参考 Zabbix+grafana监控部署 基本环境 系统: CentOS Linux release 7.3.1611 Zabbix—server: Zabbix_agent: N ...