在net.py里面构造网络,网络的结构为输入为28*28,第一层隐藏层的输出为300, 第二层输出的输出为100, 最后一层的输出层为10,

net.py

import torch
from torch import nn class Batch_Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Batch_Net, self).__init__()
self.layer_1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
self.layer_2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
self.output = nn.Sequential(nn.Linear(n_hidden_2, out_dim)) def forward(self, x):
x = self.layer_1(x)
x = self.layer_2(x)
x = self.output(x)
return x

main.py 进行网络的训练

import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms import net batch_size = 128 # 每一个batch_size的大小
learning_rate = 1e-2 # 学习率的大小
num_epoches = 20 # 迭代的epoch值
# 表示data将数据变成0, 1之间,0.5, 0.5表示减去均值处以标准差
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) # 表示均值和标准差
# 获得训练集的数据
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
# 获得测试集的数据
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)
# 获得训练集的可迭代队列
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 获得测试集的可迭代队列
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 构造模型的网络
model = net.Batch_Net(28*28, 300, 100, 10)
if torch.cuda.is_available(): # 如果有cuda就将模型放在GPU上
model.cuda() criterion = nn.CrossEntropyLoss() # 构造交叉损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # 构造模型的优化器 for epoch in range(num_epoches): # 迭代的epoch
train_loss = 0 # 训练的损失值
test_loss = 0 # 测试的损失值
eval_acc = 0 # 测试集的准确率
for data in train_loader: # 获得一个batch的样本
img, label = data # 获得图片和标签
img = img.view(img.size(0), -1) # 将图片进行img的转换
if torch.cuda.is_available(): # 如果存在torch
img = Variable(img).cuda() # 将图片放在torch上
label = Variable(label).cuda() # 将标签放在torch上
else:
img = Variable(img) # 构造img的变量
label = Variable(label)
optimizer.zero_grad() # 消除optimizer的梯度
out = model.forward(img) # 进行前向传播
loss = criterion(out, label) # 计算损失值
loss.backward() # 进行损失值的后向传播
optimizer.step() # 进行优化器的优化
train_loss += loss.data #
for data in test_loader:
img, label = data
img = img.view(img.size(0), -1)
if torch.cuda.is_available():
img = Variable(img, volatile=True).cuda()
label = Variable(label, volatile=True).cuda()
else:
img = Variable(img, volatile=True)
label = Variable(label, volatile=True)
out = model.forward(img)
loss = criterion(out, label)
test_loss += loss.data
top_p, top_class = out.topk(1, dim=1) # 获得输出的每一个样本的最大损失
equals = top_class == label.view(*top_class.shape) # 判断两组样本的标签是否相等
accuracy = torch.mean(equals.type(torch.FloatTensor)) # 计算准确率
eval_acc += accuracy
print('train_loss{:.6f}, test_loss{:.6f}, Acc:{:.6f}'.format(train_loss / len(train_loader), test_loss / len(test_loader), eval_acc / len(test_loader)))

pytorch-mnist神经网络训练的更多相关文章

  1. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  2. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  3. PyTorch Tutorials 4 训练一个分类器

    %matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...

  4. Pytorch多GPU训练

    Pytorch多GPU训练 临近放假, 服务器上的GPU好多空闲, 博主顺便研究了一下如何用多卡同时训练 原理 多卡训练的基本过程 首先把模型加载到一个主设备 把模型只读复制到多个设备 把大的batc ...

  5. 使用pytorch构建神经网络的流程以及一些问题

    使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助. PyTorch构建神经网络的一般过程 下面的程序是PyT ...

  6. Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)

    基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html  摘要 在前面的博文中,我详细介绍了Caf ...

  7. 使用PyTorch构建神经网络以及反向传播计算

    使用PyTorch构建神经网络以及反向传播计算 前一段时间南京出现了疫情,大概原因是因为境外飞机清洁处理不恰当,导致清理人员感染.话说国外一天不消停,国内就得一直严防死守.沈阳出现了一例感染人员,我在 ...

  8. 基于 PyTorch 和神经网络给 GirlFriend 制作漫画风头像

    摘要:本文中我们介绍的 AnimeGAN 就是 GitHub 上一款爆火的二次元漫画风格迁移工具,可以实现快速的动画风格迁移. 本文分享自华为云社区<AnimeGANv2 照片动漫化:如何基于 ...

  9. 神经网络训练中的Tricks之高效BP(反向传播算法)

    神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...

  10. 从零到一:caffe-windows(CPU)配置与利用mnist数据集训练第一个caffemodel

    一.前言 本文会详细地阐述caffe-windows的配置教程.由于博主自己也只是个在校学生,目前也写不了太深入的东西,所以准备从最基础的开始一步步来.个人的计划是分成配置和运行官方教程,利用自己的数 ...

随机推荐

  1. debian上安装tmux

    1.安装ncurses库 1.1.获取源码 wget https://invisible-island.net/datafiles/release/ncurses.tar.gz tar xvf ncu ...

  2. go爬虫之爬取豆瓣电影

    go爬取豆瓣电影 好久没使用go语言做个项目了,上午闲来无事花了点时间使用golang来爬取豆瓣top电影,这里我没有用colly框架而是自己设计简单流程.mark一下 思路 定义两个channel, ...

  3. Java错误和异常解析

    Java错误和异常解析 错误和异常 在Java中, 根据错误性质将运行错误分为两类: 错误和异常. 在Java程序的执行过程中, 如果出现了异常事件, 就会生成一个异常对象. 生成的异常对象将传递Ja ...

  4. 用idea操作svn

    使用SVN前提必须安装好服务端和客户端,或者知道服务端的url才能对服务器中的文件进行操作. 服务端:SVN service 客户端:TortoiseSVN 提交 第一步:确认SVN 服务器是否开启 ...

  5. mongodb命令---创建数据库,插入文档,更新文档记录

    创建数据库----基本就是使用隐式创建 例如 use 你定义的数据库名, use dingsmongo  如果你使用的是studio 3T软件,那直接选中右侧的地址栏点击右键选择Add Databas ...

  6. bzoj1458: 士兵占领(最大流)

    题目描述 有一个M * N的棋盘,有的格子是障碍.现在你要选择一些格子来放置一些士兵,一个格子里最多可以放置一个士兵,障碍格里不能放置士兵.我们称这些士兵占领了整个棋盘当满足第i行至少放置了Li个士兵 ...

  7. BZOJ 2178: 圆的面积并 (辛普森积分)

    code #include <set> #include <cmath> #include <cstdio> #include <cstring> #i ...

  8. 最大的矩形(CCF)

    问题描述 在横轴上放了n个相邻的矩形,每个矩形的宽度是1,而第i(1 ≤ i ≤ n)个矩形的高度是hi.这n个矩形构成了一个直方图.例如,下图中六个矩形的高度就分别是3, 1, 6, 5, 2, 3 ...

  9. siblings([expr])

    siblings([expr]) 概述 取得一个包含匹配的元素集合中每一个元素的所有唯一同辈元素的元素集合.可以用可选的表达式进行筛选.大理石平台维修   参数 exprStringV1.0 用于筛选 ...

  10. 用php把excel数据导入数据库

    PHPExcel是一个PHP类库,用来帮助我们简单.高效实现从Excel读取Excel的数据和导出数据到Excel. 先下载PHPExcel类库· 读取文件源码: <?php header(&q ...