pytorch-mnist神经网络训练
在net.py里面构造网络,网络的结构为输入为28*28,第一层隐藏层的输出为300, 第二层输出的输出为100, 最后一层的输出层为10,
net.py
- import torch
- from torch import nn
- class Batch_Net(nn.Module):
- def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
- super(Batch_Net, self).__init__()
- self.layer_1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
- self.layer_2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
- self.output = nn.Sequential(nn.Linear(n_hidden_2, out_dim))
- def forward(self, x):
- x = self.layer_1(x)
- x = self.layer_2(x)
- x = self.output(x)
- return x
main.py 进行网络的训练
- import torch
- from torch import nn, optim
- from torch.autograd import Variable
- from torch.utils.data import DataLoader
- from torchvision import datasets, transforms
- import net
- batch_size = 128 # 每一个batch_size的大小
- learning_rate = 1e-2 # 学习率的大小
- num_epoches = 20 # 迭代的epoch值
- # 表示data将数据变成0, 1之间,0.5, 0.5表示减去均值处以标准差
- data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) # 表示均值和标准差
- # 获得训练集的数据
- train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
- # 获得测试集的数据
- test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)
- # 获得训练集的可迭代队列
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
- # 获得测试集的可迭代队列
- test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
- # 构造模型的网络
- model = net.Batch_Net(28*28, 300, 100, 10)
- if torch.cuda.is_available(): # 如果有cuda就将模型放在GPU上
- model.cuda()
- criterion = nn.CrossEntropyLoss() # 构造交叉损失函数
- optimizer = optim.SGD(model.parameters(), lr=learning_rate) # 构造模型的优化器
- for epoch in range(num_epoches): # 迭代的epoch
- train_loss = 0 # 训练的损失值
- test_loss = 0 # 测试的损失值
- eval_acc = 0 # 测试集的准确率
- for data in train_loader: # 获得一个batch的样本
- img, label = data # 获得图片和标签
- img = img.view(img.size(0), -1) # 将图片进行img的转换
- if torch.cuda.is_available(): # 如果存在torch
- img = Variable(img).cuda() # 将图片放在torch上
- label = Variable(label).cuda() # 将标签放在torch上
- else:
- img = Variable(img) # 构造img的变量
- label = Variable(label)
- optimizer.zero_grad() # 消除optimizer的梯度
- out = model.forward(img) # 进行前向传播
- loss = criterion(out, label) # 计算损失值
- loss.backward() # 进行损失值的后向传播
- optimizer.step() # 进行优化器的优化
- train_loss += loss.data #
- for data in test_loader:
- img, label = data
- img = img.view(img.size(0), -1)
- if torch.cuda.is_available():
- img = Variable(img, volatile=True).cuda()
- label = Variable(label, volatile=True).cuda()
- else:
- img = Variable(img, volatile=True)
- label = Variable(label, volatile=True)
- out = model.forward(img)
- loss = criterion(out, label)
- test_loss += loss.data
- top_p, top_class = out.topk(1, dim=1) # 获得输出的每一个样本的最大损失
- equals = top_class == label.view(*top_class.shape) # 判断两组样本的标签是否相等
- accuracy = torch.mean(equals.type(torch.FloatTensor)) # 计算准确率
- eval_acc += accuracy
- print('train_loss{:.6f}, test_loss{:.6f}, Acc:{:.6f}'.format(train_loss / len(train_loader), test_loss / len(test_loader), eval_acc / len(test_loader)))
pytorch-mnist神经网络训练的更多相关文章
- tensorflow中使用mnist数据集训练全连接神经网络-学习笔记
tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...
- Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)
Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...
- PyTorch Tutorials 4 训练一个分类器
%matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...
- Pytorch多GPU训练
Pytorch多GPU训练 临近放假, 服务器上的GPU好多空闲, 博主顺便研究了一下如何用多卡同时训练 原理 多卡训练的基本过程 首先把模型加载到一个主设备 把模型只读复制到多个设备 把大的batc ...
- 使用pytorch构建神经网络的流程以及一些问题
使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助. PyTorch构建神经网络的一般过程 下面的程序是PyT ...
- Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)
基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html 摘要 在前面的博文中,我详细介绍了Caf ...
- 使用PyTorch构建神经网络以及反向传播计算
使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...
- 基于 PyTorch 和神经网络给 GirlFriend 制作漫画风头像
摘要:本文中我们介绍的 AnimeGAN 就是 GitHub 上一款爆火的二次元漫画风格迁移工具,可以实现快速的动画风格迁移. 本文分享自华为云社区<AnimeGANv2 照片动漫化:如何基于 ...
- 神经网络训练中的Tricks之高效BP(反向传播算法)
神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...
- 从零到一:caffe-windows(CPU)配置与利用mnist数据集训练第一个caffemodel
一.前言 本文会详细地阐述caffe-windows的配置教程.由于博主自己也只是个在校学生,目前也写不了太深入的东西,所以准备从最基础的开始一步步来.个人的计划是分成配置和运行官方教程,利用自己的数 ...
随机推荐
- Java并发编程之线程池及示例
1.Executor 线程池顶级接口.定义方法,void execute(Runnable).方法是用于处理任务的一个服务方法.调用者提供Runnable 接口的实现,线程池通过线程执行这个 Runn ...
- 编译luacheck Linux版
最近在写Visual Studio Code的Lua插件,需要把luacheck集成进去.但是luacheck默认只提供了win32版本,见https://github.com/mpeterv/lua ...
- redis __详解 (转载自作者:孤独烟 出处: http://rjzheng.cnblogs.com/)
https://www.cnblogs.com/rjzheng/p/9096228.html [原创]分布式之redis复习精讲 引言 为什么写这篇文章? 博主的<分布式之消息队列复习精讲> ...
- 如何自动运行loadrunner脚本
问题背景 在凌晨之后,自然流量比较低,无需人值守的情况自动运行loadruner脚本. 实现思路 windows定时任务+BAT脚本 BAT脚本: SET M_ROOT=C:\Program File ...
- 多线程--volatile
在解释volatile关键字之前,先说说java的指令重排以及代码的执行顺序. 指令重排: public void sum(){ int x = 1; int y = 2; int x = x + 1 ...
- linux数码管驱动程序和应用程序
- day01_人类社会货币的演变
1.货币的自然演变 1.1:从实物货币(贝壳.金银等一般等价物的稀有性等价于被交换物品的价值)---纸质货币(国家信用背书,使得一文不值的纸币可以兑换价值百元的商品)---记账货币(微信.二维码.银行 ...
- java--mybatis的实现原理
动态代理? 需要调试下,看下源码,再研究下……
- zznu-2183: 口袋魔方
大致题意: 题目描述 口袋魔方又称为迷你魔方,通俗的来讲就是二阶魔方,只有八个角块的魔方,如图所示. 二阶魔方8个角块的位置均可进行任意互换(!种状态),如果以一个角块不动作为参考角块,其他7个 角块 ...
- git 忽略部分文件
忽略: git update-index --assume-unchanged .mymetadata 取消忽略: git update-index --no-assume-unchanged