Pytorch学习之源码理解:pytorch/examples/mnists

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()#调用父类的构造方法
self.conv1 = nn.Conv2d(1, 32, 3, 1)#输入1个channel,输出32个channels,kernel_size=3,stride(步长)=1
self.conv2 = nn.Conv2d(32, 64, 3, 1)#再变成64channels
self.dropout1 = nn.Dropout2d(0.25)#以0.25的概率dropout
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)#9216->128
self.fc2 = nn.Linear(128, 10)
#定义网络各层
def forward(self, x):
x = self.conv1(x)
#线性整流函数(Rectified Linear Unit, ReLU)是一个激活函数,这是当成一层了
#卷积神经网络中,若不采用非线性激活,会导致神经网络只能拟合线性可分的数据,因此通常会在卷积操作后,添加非线性激活单元,其中包括logistic-sigmoid、tanh-sigmoid、ReLU等。
x = F.relu(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output def train(args, model, device, train_loader, optimizer, epoch):
model.train()
#这是两种模式
#model.train() :启用 BatchNormalization 和 Dropout
#model.eval() :不启用 BatchNormalization 和 Dropout
#model.eval(),pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。
# 不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大;在模型测试阶段使用
#trainloader对每一个batch加了id
for batch_idx, (data, target) in enumerate(train_loader):
#读入数据到device中,之后就用新的变量表示就可,对程序不影响(物理层和应用层)
data, target = data.to(device), target.to(device)
optimizer.zero_grad()#初始化优化器参数
output = model(data)
loss = F.nll_loss(output, target)#计算loss
loss.backward()#反向传播,计算梯度和
optimizer.step()#调整参数。
#上面的方法都是共享一个参数空间的,所以不需要传递参数
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item())) def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset))) def main():
# Training settings
#都是可选参数,是为了调参用的
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')#加上参数描述,在--help中输出
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status') parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()#获取参数,从这里就可以开始调用这些参数了。没有输入也没有设置默认值的就是null,用在布尔表达式里面也可以表示false
use_cuda = not args.no_cuda and torch.cuda.is_available()#有cuda并且没设置参数说不用才用cuda torch.manual_seed(args.seed)#设置随机种子,以便于生成随机数 device = torch.device("cuda" if use_cuda else "cpu")#决定用cpu还是GPU kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
#载入训练集
train_loader = torch.utils.data.DataLoader(
#torchvision下的datasets模块,如果没发现本地有这个包就下载
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),#输出tensor类型
transforms.Normalize((0.1307,), (0.3081,))#do normalize
])),
batch_size=args.batch_size, shuffle=True, **kwargs)#一次读多少
#载入测试集
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
#将模型读入device
model = Net().to(device)
#设置优化器,这里使用的是Adagrad优化方法(Adaptive Gradient)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
#等间隔调整学习率 StepLR
#等间隔调整学习率,调整倍数为 gamma 倍,调整间隔为 step_size。间隔单位是step。需要注意的是, step 通常是指 epoch,不要弄成 iteration 了。
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):#迭代次数
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
scheduler.step()#每次迭代之后调整学习率 if args.save_model:#保存模型
torch.save(model.state_dict(), "mnist_cnn.pt") if __name__ == '__main__':
main()

Pytorch学习之源码理解:pytorch/examples/mnists的更多相关文章

  1. [源码解析] PyTorch 分布式(1)------历史和概述

    [源码解析] PyTorch 分布式(1)------历史和概述 目录 [源码解析] PyTorch 分布式(1)------历史和概述 0x00 摘要 0x01 PyTorch分布式的历史 1.1 ...

  2. [源码解析] PyTorch 如何使用GPU

    [源码解析] PyTorch 如何使用GPU 目录 [源码解析] PyTorch 如何使用GPU 0x00 摘要 0x01 问题 0x02 移动模型到GPU 2.1 cuda 操作 2.2 Modul ...

  3. [源码解析] PyTorch 分布式(15) --- 使用分布式 RPC 框架实现参数服务器

    [源码解析] PyTorch 分布式(15) --- 使用分布式 RPC 框架实现参数服务器 目录 [源码解析] PyTorch 分布式(15) --- 使用分布式 RPC 框架实现参数服务器 0x0 ...

  4. [源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC

    [源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC 目录 [源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC 0x00 摘要 0x0 ...

  5. [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行

    [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行 目录 [源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行 0x00 摘要 0x0 ...

  6. [源码解析] PyTorch 分布式之弹性训练(3)---代理

    [源码解析] PyTorch 分布式之弹性训练(3)---代理 目录 [源码解析] PyTorch 分布式之弹性训练(3)---代理 0x00 摘要 0x01 总体背景 1.1 功能分离 1.2 Re ...

  7. [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler

    [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...

  8. [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

    [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...

  9. [源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统

    [源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统 目录 [源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统 0x00 摘要 0x01 分割小批次 ...

随机推荐

  1. Python 自学笔记(四)

    1.for...in...循环语句 1-1.遍历列表 1-2.遍历字典 1-2-1.遍历字典的键和值 1-2-2.遍历字典的键值(一) 1-2-3.遍历字典的键值(二) 1-2-4.遍历字典的值 1- ...

  2. Vue UI组件库

    1. iView UI组件库  iView官网:https://www.iviewui.com/ 2.Vux UI组件库   Vux官网:https://vux.li/ 3.Element UI组件库 ...

  3. docker swarm和 k8s对比

    Swarm的优势:swarm API兼容docker API,使得swarm 学习成本低,同时架构简单,部署运维成本较低.Swarm的劣势:同样是因为API兼容,无法提供集群的更加精细的管理.在网络方 ...

  4. Canvas-基本用法

    Canvas教程-MDN HTML 5 Canvas 参考手册 <canvas>是一个可以使用脚本(通常为JavaScript)来绘制图形的 HTML 元素.例如,它可以用于绘制图表.制作 ...

  5. 浅谈smarty模板的mvc框架

    最近接触了一个大项目,php做的后台管理,融合了smarty模板+mvc框架+phpcms内容管理,,,这个项目简直就是php的精华,于是小编大哥对项目小女子产生了兴趣,打算一点一点把她征服.现在小吃 ...

  6. PAT 甲级 1013 Battle Over Cities (25 分)(图的遍历,统计强连通分量个数,bfs,一遍就ac啦)

    1013 Battle Over Cities (25 分)   It is vitally important to have all the cities connected by highway ...

  7. 怎样获取java新IO的Path文件大小

    import org.junit.Test; import java.io.IOException; import java.nio.file.Files; import java.nio.file. ...

  8. python函数,定义,参数,返回值

    python中可以将某些具备一定功能的代码写成一个函数,通过函数可以在一定程度上减少代码的冗余,节约书写代码的时间.因为有一些代码实现的功能我们可能会在很多地方用到. 1.函数的声明与定义 通过def ...

  9. 【FICO系列】SAP FI验证故障排除(调试)

    公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[FICO系列]SAP FI验证故障排除(调试) ...

  10. 单例Bean注册表接口SingletonBeanRegistry

    Github: SingletonBeanRegistry.java SingletonBeanRegistry package org.springframework.beans.factory.c ...