pytorch-MNIST数据模型测试
用pytorch搭建一个DNN网络,主要目的是熟悉pytorch的使用
"""
test Function
""" import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms class simpleNet(nn.Module):
''' define the 3 layers Network'''
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(simpleNet, self).__init__()
self.layer1 = nn.Linear(in_dim, n_hidden_1)
self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
self.layer3 = nn.Linear(n_hidden_2, out_dim) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x class Activation_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Activation_Net, self).__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1), nn.ReLU(True)
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(True)
)
self.layer3 = nn.Sequential(
nn.Linear(n_hidden_2, out_dim)
) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x class Batch_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Batch_Net, self).__init__()
self.layer1 = nn.Sequential(
nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1) ,nn.ReLU(True)
)
self.layer2 = nn.Sequential(
nn.Linear(n_hidden_1,n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True)
)
self.layer3 = nn.Sequential(
nn.Linear(n_hidden_2, out_dim)
) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x batch_size = 64
learning_rate = 1e-2
num_epochs = 20 data_tf = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
) train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) model = Batch_Net(28*28, 300, 100, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # Training
epoch = 0
for data in train_loader:
img, label = data
img = img.view(img.size(0), -1)
img = Variable(img)
label = Variable(label)
out = model(img)
loss = criterion(out, label)
print_loss = loss.data.item() optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch += 1
if epoch % 50 == 0:
print('epoch:{}, loss:{:.4f}'.format(epoch, loss.data.item())) # Evalue
model.eval() # turn the model to test pattern, do some as dropout, batchNormalization
eval_loss = 0
eval_acc = 0
for data in test_loader:
img, label = data
img = img.view(img.size(0), -1)
img = Variable(img) # 前向传播不需要保留缓存,释放掉内存,节约内存空间
label = Variable(label)
out = model(img)
loss = criterion(out, label) eval_loss += loss.data * label.size(0)
_, pred = torch.max(out, 1) # 返回每一行中最大值和对应的索引
s = (pred == label)
num_correct = (pred == label).sum()
eval_acc += num_correct.data.item()
print('Test Loss:{:6f}, Acc:{:.6f}'.format(eval_loss/len(test_dataset), eval_acc/len(test_dataset)))
pytorch-MNIST数据模型测试的更多相关文章
- Tensorflow MNIST 数据集测试代码入门
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50614444 测试代码已上传至GitH ...
- 深入MNIST code测试
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50624471 依照教程:深入MNIST ...
- highway network及mnist数据集测试
先说结论:没经过仔细调参,打不开论文所说代码链接(fq也没打开),结果和普通卷积网络比较没有优势.反倒是BN对网络起着非常重要的作用,达到了99.17%的测试精度(训练轮数还没到过拟合). 论文为&l ...
- mxnet卷积神经网络训练MNIST数据集测试
mxnet框架下超全手写字体识别—从数据预处理到网络的训练—模型及日志的保存 import numpy as np import mxnet as mx import logging logging. ...
- 如何使用Pytorch迅速实现Mnist数据及分类器
一段时间没有更新博文,想着也该写两篇文章玩玩了.而从一个简单的例子作为开端是一个比较不错的选择.本文章会手把手地教读者构建一个简单的Mnist(Fashion-Mnist同理)的分类器,并且会使用相对 ...
- Caffe初试(二)windows下的cafee训练和测试mnist数据集
一.mnist数据集 mnist是一个手写数字数据库,由Google实验室的Corinna Cortes和纽约大学柯朗研究院的Yann LeCun等人建立,它有60000个训练样本集和10000个测试 ...
- 使用xshell+xmanager+pycharm搭建pytorch远程调试开发环境
1. 相关软件版本 xshell: xmanager: pycharm: pycharm破解服务器:https://jetlicense.nss.im/ 2. 将相应的软件安装(pojie好) a&g ...
- Pytorch学习之源码理解:pytorch/examples/mnists
Pytorch学习之源码理解:pytorch/examples/mnists from __future__ import print_function import argparse import ...
- [源码解析] PyTorch 分布式(4)------分布式应用基础概念
[源码解析] PyTorch 分布式(4)------分布式应用基础概念 目录 [源码解析] PyTorch 分布式(4)------分布式应用基础概念 0x00 摘要 0x01 基本概念 0x02 ...
随机推荐
- PHP异步扩展Swoole笔记(2)
dispatch_mode, 数据包分发策略 可以选择7种类型,默认为21,轮循模式,收到会轮循分配给每一个Worker进程2,固定模式,根据连接的文件描述符分配Worker.这样可以保证同一个连接发 ...
- Ubuntu 16.04常用快捷键
注意:在Linux下Win键就是Super键 启动器 Win(长按) 打开启动器,显示快捷键 Win + Tab 通过启动器切换应用程序 Win + 1到9 与点击启动器上的图标效果一样 Win + ...
- 【VS2019】F12跳转到源码,关闭浏览器不停止项目【转】
[VS2019]F12跳转到源码 1.工具->选项 2.文本编辑器->C#->高级->勾选支持导航到反编译源码 3.关闭浏览器不停止项目
- Unity应用架构设计(5)——ViewModel之间如何共享数据
对于客户端应用程序而言,单页应用程序(Single Page Application)是最常见的表现形式.有经验的开发人员往往会把一个View分解多个SubView.那么,如何在多个SubView之间 ...
- 中文分词工具thulac4j发布
1. 介绍 thulac4j是THULAC的Java 8工程化实现,具有分词速度快.准.强的特点:支持 自定义词典 繁体转简体 停用词过滤 若想在项目中使用thulac4j,可添加依赖: <de ...
- 小型互联网公司的IT系统建设思路
最近一些想创业的一帮兄弟来问我,准备借助互联网的翅膀,做某某事情,并想尽快的做出一个系统平台. 我给的思路,分6个步骤: 需求-> 灵感设计 ->实现 ->迭代改进 ->成 ...
- 分布式Id教程
转自:https://baijiahao.baidu.com/s?id=1584913615817222458&wfr=spider&for=pc 一,题记 所有的业务系统,都有生成I ...
- Java编程的逻辑 (94) - 组合式异步编程
本系列文章经补充和完善,已修订整理成书<Java编程的逻辑>,由机械工业出版社华章分社出版,于2018年1月上市热销,读者好评如潮!各大网店和书店有售,欢迎购买,京东自营链接:http: ...
- Activity的Launch mode详解,A B C D的singleTask模式
本文参考了此文http://hi.baidu.com/amauri3389/blog/item/a54475c2a4b2f040b219a86a.html 另附 android task与back s ...
- C语言 sscanf用法详解
/* sscanf用法详解 */ #include <stdio.h> /* sscanf头文件 */ #include <stdlib.h> #include <str ...