第一次,调了很久。它本来已经很OK了,同时适用CPU和GPU,且可正常运行的。

为了用于性能测试,主要改了三点:

一,每一批次显示处理时间。

二,本地加载测试数据。

三,兼容LINUX和WIN

本地加载测试数据时,要注意是用将两个pt文件,放在processed目录下,raw目录不要即可。

训练数据的定义目录是在当前目录 data/MNIST/processed目录下。

我自己弄了个下载:

http://u.163.com/2FUm6N1L  提取码: XJpmqUoR

只能下载20次,过了可在此留言。

import os
import timeit
import torch                     # pytorch 最基本模块
import torch.nn as nn            # pytorch中最重要的模块,封装了神经网络相关的函数
import torch.nn.functional as F  # 提供了一些常用的函数,如softmax
import torch.optim as optim      # 优化模块,封装了求解模型的一些优化器,如Adam SGD
from torch.optim import lr_scheduler # 学习率调整器,在训练过程中合理变动学习率
from torchvision import transforms  #pytorch 视觉库中提供了一些数据变换的接口
from torchvision import datasets  #pytorch 视觉库提供了加载数据集的接口

DATA_DIR = os.path.join(os.getcwd(),"data")
# 预设网络超参数 (所谓超参数就是可以人为设定的参数

BATCH_SIZE= 64 # 由于使用批量训练的方法,需要定义每批的训练的样本数目

EPOCHS=3      # 总共训练迭代的次数

# 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

learning_rate = 0.1  # 设定初始的学习率

# 加载训练集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_DIR, train=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=(0.5,), std=(0.5,)) # 数据规范化到正态分布
                    ])),
    batch_size=BATCH_SIZE, shuffle=True) # 指明批量大小,打乱,这是处于后续训练的需要。

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_DIR, train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                    ])),
    batch_size=BATCH_SIZE, shuffle=True)

# 设计模型
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # 提取特征层
        self.features = nn.Sequential(
            # 卷积层
            # 输入图像通道为 1,因为我们使用的是黑白图,单通道的
            # 输出通道为32(代表使用32个卷积核),一个卷积核产生一个单通道的特征图
            # 卷积核kernel_size的尺寸为 3 * 3,stride 代表每次卷积核的移动像素个数为1
            # padding 填充,为1代表在图像长宽都多了两个像素
            nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size=3, stride=1, padding=1),

            # 批量归一化,跟上一层的out_channels大小相等,以下的通道规律也是必须要对应好的
            nn.BatchNorm2d(num_features = 32),

            # 激活函数,inplace=true代表直接进行运算
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            # 最大池化层
            # kernel_size 为2 * 2的滑动窗口
            # stride为2,表示每次滑动距离为2个像素
            # 经过这一步,图像的大小变为1/4,即 28 * 28 -》 14 * 14
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2) # 14 * 14 -》 7 * 7
        )
         # 分类层
        self.classifier = nn.Sequential(
            # Dropout层
            # p = 0.5 代表该层的每个权重有0.5的可能性为0
            nn.Dropout(p = 0.5),
            # 这里是通道数64 * 图像大小7 * 7,然后输入到512个神经元中
            nn.Linear(64 * 7 * 7, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p = 0.5),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p = 0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        # 经过特征提取层
        x = self.features(x)
        # 输出结果必须展平成一维向量
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x  

# 初始化模型
ConvModel = ConvNet().to(DEVICE)
# 定义交叉熵损失函数
criterion = nn.CrossEntropyLoss().to(DEVICE)
# 定义模型优化器
optimizer = torch.optim.Adam(ConvModel.parameters(), lr = learning_rate)
# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

def train(num_epochs,_model, _device, _train_loader, _optimizer, _lr_scheduler):
    _model.train()
    _lr_scheduler.step()
    for epoch in range(num_epochs):
        start = end = 0
        # 从迭代器抽取图片和标签
        for i, (images, labels) in enumerate(_train_loader):
            if (i + 1) % 100 == 1:
                start = timeit.default_timer()
            samples = images.to(_device)
            labels = labels.to(_device)
            #此时样本是一批图片,在CNN的输入中,我们需要将其变为四维,
            # reshape第一个-1 代表自动计算批量图片的数目n
            # 最后reshape得到的结果就是n张图片,每一张图片都是单通道的28 * 28,得到四维张量
            output = _model(samples.reshape(-1, 1, 28, 28))

            # 计算损失函数值
            loss = criterion(output, labels)

            # 优化器内部参数梯度必须变为0
            optimizer.zero_grad()

            # 损失值后向传播
            loss.backward()

            # 更新模型参数
            optimizer.step()

            if (i + 1) % 100 == 0:
                end = timeit.default_timer()
                print("Epoch:{}/{}, Time:{}s, step:{}, loss:{:.4f}".format(epoch+1, num_epochs, end-start, i + 1, loss.item()))

def test(_test_loader, _model, _device):
    _model.eval() # 设置模型进入预测模式 evaluation
    loss = 0
    correct = 0

    with torch.no_grad(): #如果不需要 backward更新梯度,那么就要禁用梯度计算,减少内存和计算资源浪费。
        for data, target in _test_loader:
            data, target = data.to(_device), target.to(_device)
            output = ConvModel(data.reshape(-1, 1, 28, 28))
            loss += criterion(output, target).item() # 添加损失值
            pred = output.data.max(1, keepdim=True)[1] # 找到概率最大的下标,为输出值
            correct += pred.eq(target.data.view_as(pred)).cpu().sum() # .cpu()是将参数迁移到cpu上来。

    loss /= len(_test_loader.dataset)

    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        loss, correct, len(_test_loader.dataset),
        100. * correct / len(_test_loader.dataset)))

for epoch in range(1, EPOCHS + 1):
    train(epoch, ConvModel, DEVICE, train_loader, optimizer, exp_lr_scheduler)
    test(test_loader,ConvModel, DEVICE)
    test(train_loader,ConvModel, DEVICE)

一套兼容win和Linux的PyTorch训练MNIST的算法代码(CNN)的更多相关文章

  1. php中路径斜杠的应用,兼容win与linux

    更多内容推荐微信公众号,欢迎关注: PHP中斜杠的运用 兼容win和linux 使用常量:DIRECTORY_SEPARATOR如:"www".DIRECTORY_SEPARATO ...

  2. 跨平台设置NODE_ENV(兼容win和linux)

    通过NODE_ENV可以来设置环境变量(默认值为development).一般我们通过检查这个值来分别对开发环境和生产环境下做不同的处理.可以在命令行中通过下面的方式设置这个值: linux & ...

  3. 用Pytorch训练MNIST分类模型

    本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...

  4. Sublime Text 2 - 性感无比的代码编辑器!程序员必备神器!跨平台支持Win/Mac/Linux

    我用过的编辑器不少,真不少- 但却没有哪款让我特别心仪的,直到我遇到了 Sublime Text 2 !如果说“神器”是我能给予一款软件最高的评价,那么我很乐意为它封上这么一个称号.它小巧绿色且速度非 ...

  5. [转载]Sublime Text 2 - 性感无比的代码编辑器!程序员必备神器!跨平台支持Win/Mac/Linux

    代码编辑器或者文本编辑器,对于程序员来说,就像剑与战士一样,谁都想拥有一把可以随心驾驭且锋利无比的宝剑,而每一位程序员,同样会去追求最适合自己的强大.灵活的编辑器,相信你和我一样,都不会例外. 我用过 ...

  6. Java文件夹操作,判断多级路径是否存在,不存在就创建(包括windows和linux下的路径字符分析),兼容Windows和Linux

    兼容windows和linux. 分析: 在windows下路径有以下表示方式: (标准)D:\test\1.txt (不标准,参考linux)D:/test/1.txt 然后在java中,尤其使用F ...

  7. paip兼容windows与linux的java类根目录路径的方法

    paip兼容windows与linux的java类根目录路径的方法 1.只有 pathx.class.getResource("")或者pathx.class.getResourc ...

  8. redhat 安装配置samba实现win共享linux主机目录

    [转]http://blog.chinaunix.net/uid-26642180-id-3135941.html redhat 安装配置samba实现win共享linux主机目录 2012-03-1 ...

  9. Win和Linux查看端口和杀死进程

    title: Win和Linux查看端口和杀死进程 date: 2017-7-30 tags: null categories: Linux --- 本文介绍Windows和Linux下查看端口和杀死 ...

随机推荐

  1. BatchConfigTool批量配置工具

    海康批量配置工具BatchConfigTool是一款支持设备在线搜索.批量配置参数.批量升级等功能的软件,支持对大批量设备同时进行各参数的配置,极大的简化了操作过程! 软件功能 1.对在线设备进行搜索 ...

  2. 【Spring Cloud学习之三】负载均衡

    环境 eclipse 4.7 jdk 1.8 Spring Boot 1.5.2 Spring Cloud 1.2 主流的负载均衡技术有nginx.LVS.HAproxy.F5,Spring Clou ...

  3. ifcopenshell在VS2015下的编译

    源起 今天使用 IfcOpenShell的IfcConvert ,因为是开源的所以就想自己编译下,编译过程中遇到不少问题,因此记录下来 什么是IfcOpenShell? IfcOpenShell是一个 ...

  4. 【ARM-Linux开发】【CUDA开发】NVIDIA TEGRA X1:LINUX驱动程序包多媒体用户指南

    NVIDIA TEGRA X1:LINUX驱动程序包多媒体用户指南 转载请注明作者和出处:http://blog.csdn.net/u011475210 嵌入式平台:NVIDIA Jetson TX1 ...

  5. Three.js场景的基本组件

    1.场景Scene THREE.Scene被称为场景图,可以用来保存所有图形场景的必要信息.每个添加到Scene的对象,包括Scene自身都继承自名为THREE.Object3D对象.Scene不仅仅 ...

  6. Django-10-分页组件

    1. Django内置分页 from django.shortcuts import render from django.core.paginator import Paginator, Empty ...

  7. golang 之文件操作

    文件操作要理解一切皆文件. Go 在 os 中提供了文件的基本操作,包括通常意义的打开.创建.读写等操作,除此以外为了追求便捷以及性能上,Go 还在 io/ioutil 以及 bufio 提供一些其他 ...

  8. IDEA设置虚拟机参数

    第一步:打开“Run->Edit Configurations”菜单 第二步:选择“VM Options”选项,输入你要设置的VM参数 第三步:点击“OK”.“Apply”后设置完成

  9. Entity Framework Codefirst的配置步骤

    Entity Framework Codefirst的配置步骤: (1) 安装命令: install-package entityframework (2) 创建实体类,注意virtual关键字在导航 ...

  10. WebRTC 入门教程(二)| WebRTC信令控制与STUN/TURN服务器搭建

    WebRTC 入门教程(二)| WebRTC信令控制与STUN/TURN服务器搭建 四月 4, 2019 作者:李超,音视频技术专家.本文首发于 RTC 开发者社区,欢迎在社区留言与作者交流. htt ...