#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03' import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data train_data = torchvision.datasets.MNIST(
'./mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
'./mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64) class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2))
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(32, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(64, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.dense = torch.nn.Sequential(
torch.nn.Linear(64 * 3 * 3, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10)
) def forward(self, x):
conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out)
conv3_out = self.conv3(conv2_out)
res = conv3_out.view(conv3_out.size(0), -1)
out = self.dense(res)
return out model = Net()
print(model) optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss() for epoch in range(10):
print('epoch {}'.format(epoch + 1))
# training-----------------------------
train_loss = 0.
train_acc = 0.
for batch_x, batch_y in train_loader:
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
train_loss += loss.data[0]
pred = torch.max(out, 1)[1]
train_correct = (pred == batch_y).sum()
train_acc += train_correct.data[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
train_data)), train_acc / (len(train_data)))) # evaluation--------------------------------
model.eval()
eval_loss = 0.
eval_acc = 0.
for batch_x, batch_y in test_loader:
batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
out = model(batch_x)
loss = loss_func(out, batch_y)
eval_loss += loss.data[0]
pred = torch.max(out, 1)[1]
num_correct = (pred == batch_y).sum()
eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
test_data)), eval_acc / (len(test_data))))

Pytorch入门实例:mnist分类训练的更多相关文章

  1. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  2. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  3. 超简单!pytorch入门教程(五):训练和测试CNN

    我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...

  4. pytorch入门2.2构建回归模型初体验(开始训练)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  5. Pytorch入门中 —— 搭建网络模型

    本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...

  6. Pytorch入门之VAE

    关于自编码器的原理见另一篇博客 : 编码器AE & VAE 这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现. 1. 稀疏编码 首先介绍一下“稀疏编码 ...

  7. pytorch 入门指南

    两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...

  8. Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader

    本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...

  9. Pytorch入门下 —— 其他

    本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...

随机推荐

  1. 027 storm面试小题

    1.大纲 Storm工作原理是什么? 流的模式是什么?默认是什么? 对于mapreduce如何理解? Storm的特点和特性是什么? Storm组件有哪些? 2.Storm工作原理是什么? 相对于ha ...

  2. 微信小程序支付遇到的坑

    1,微信公众号支付和微信小程序支付有差异 微信公众号:可以直接跳转走h5的微信支付 微信小程序:在测试环境.沙箱环境使用微信公众号的跳转支付没有问题,在线上存在支付异常 最后商讨的解决方法 openi ...

  3. Java 扫描实现 Ioc 动态注入,过滤器根据访问url调用自定义注解标记的类及其方法

    扫描实现 Ioc 动态注入 参考: http://www.private-blog.com/2017/11/16/java-%e6%89%ab%e6%8f%8f%e5%ae%9e%e7%8e%b0-i ...

  4. Objective-C中整数与字符串的相互转换

    2013/4/15整理: 将整数转换成字符串 Convert Integer to NSString: 方法一: int Value = 112233; NSString *ValueString = ...

  5. 图形上下文导论(Introduction to SWT Graphics)zz

    图形上下文导论(Introduction to SWT Graphics) 摘要: org.eclipse.swt.graphics包(package),包含了管理图形资源的类.只要实现了org.ec ...

  6. 2019-3-22c# TextBox只允许输入数字,禁用右键粘贴,允许Ctrl+v粘贴数字

    TextBox 禁止复制粘贴 ShortcutsEnabled =false TextBox只允许输入数字,最大长度为10 //TextBox.ShortcutsEnabled为false 禁止右键和 ...

  7. CTSC2018 被屠记

    占坑 day-1 Gery 和 lll 在火车上谈笑风生 day0 上午 Gery:"我要吊打全场" 下午 Gery 忘带杯子了. Gery:"我过两天碰杯了就可以喝到水 ...

  8. Y1S002 xshell脚本编写示意

    SecureCRT可以自己录制脚本,非常的方便:但是考虑到CRT收费,所以不计划把CRT作为使用的终端. vbs脚本(test.vbs): Sub main xsh.Screen.Synchronou ...

  9. pip install

    pip install <包名> 或 pip install -r requirements.txt 通过使用 == >= <= > < 来指定版本,不写则安装最新 ...

  10. mysql数据库内容相关操作

    第一:介绍 mysql数据内容的操作主要是: INSERT实现数据的插入 UPDATE实现数据的更新 DLETE实现数据的删除 SELECT实现数据的查询. 第二:增(insert) 1.插入完整的数 ...