第一次,调了很久。它本来已经很OK了,同时适用CPU和GPU,且可正常运行的。

为了用于性能测试,主要改了三点:

一,每一批次显示处理时间。

二,本地加载测试数据。

三,兼容LINUX和WIN

本地加载测试数据时,要注意是用将两个pt文件,放在processed目录下,raw目录不要即可。

训练数据的定义目录是在当前目录 data/MNIST/processed目录下。

我自己弄了个下载:

http://u.163.com/2FUm6N1L  提取码: XJpmqUoR

只能下载20次,过了可在此留言。

  1. import os
  2. import timeit
  3. import torch # pytorch 最基本模块
  4. import torch.nn as nn # pytorch中最重要的模块,封装了神经网络相关的函数
  5. import torch.nn.functional as F # 提供了一些常用的函数,如softmax
  6. import torch.optim as optim # 优化模块,封装了求解模型的一些优化器,如Adam SGD
  7. from torch.optim import lr_scheduler # 学习率调整器,在训练过程中合理变动学习率
  8. from torchvision import transforms #pytorch 视觉库中提供了一些数据变换的接口
  9. from torchvision import datasets #pytorch 视觉库提供了加载数据集的接口
  10.  
  11. DATA_DIR = os.path.join(os.getcwd(),"data")
  12. # 预设网络超参数 (所谓超参数就是可以人为设定的参数
  13.  
  14. BATCH_SIZE= 64 # 由于使用批量训练的方法,需要定义每批的训练的样本数目
  15.  
  16. EPOCHS=3 # 总共训练迭代的次数
  17.  
  18. # 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多
  19. DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  20.  
  21. learning_rate = 0.1 # 设定初始的学习率
  22.  
  23. # 加载训练集
  24. train_loader = torch.utils.data.DataLoader(
  25. datasets.MNIST(DATA_DIR, train=True,
  26. transform=transforms.Compose([
  27. transforms.ToTensor(),
  28. transforms.Normalize(mean=(0.5,), std=(0.5,)) # 数据规范化到正态分布
  29. ])),
  30. batch_size=BATCH_SIZE, shuffle=True) # 指明批量大小,打乱,这是处于后续训练的需要。
  31.  
  32. test_loader = torch.utils.data.DataLoader(
  33. datasets.MNIST(DATA_DIR, train=False, transform=transforms.Compose([
  34. transforms.ToTensor(),
  35. transforms.Normalize((0.5,), (0.5,))
  36. ])),
  37. batch_size=BATCH_SIZE, shuffle=True)
  38.  
  39. # 设计模型
  40. class ConvNet(nn.Module):
  41. def __init__(self):
  42. super(ConvNet, self).__init__()
  43. # 提取特征层
  44. self.features = nn.Sequential(
  45. # 卷积层
  46. # 输入图像通道为 1,因为我们使用的是黑白图,单通道的
  47. # 输出通道为32(代表使用32个卷积核),一个卷积核产生一个单通道的特征图
  48. # 卷积核kernel_size的尺寸为 3 * 3,stride 代表每次卷积核的移动像素个数为1
  49. # padding 填充,为1代表在图像长宽都多了两个像素
  50. nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size=3, stride=1, padding=1),
  51.  
  52. # 批量归一化,跟上一层的out_channels大小相等,以下的通道规律也是必须要对应好的
  53. nn.BatchNorm2d(num_features = 32),
  54.  
  55. # 激活函数,inplace=true代表直接进行运算
  56. nn.ReLU(inplace=True),
  57. nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
  58. nn.BatchNorm2d(32),
  59. nn.ReLU(inplace=True),
  60.  
  61. # 最大池化层
  62. # kernel_size 为2 * 2的滑动窗口
  63. # stride为2,表示每次滑动距离为2个像素
  64. # 经过这一步,图像的大小变为1/4,即 28 * 28 -》 14 * 14
  65. nn.MaxPool2d(kernel_size=2, stride=2),
  66. nn.Conv2d(32, 64, kernel_size=3, padding=1),
  67. nn.BatchNorm2d(64),
  68. nn.ReLU(inplace=True),
  69. nn.Conv2d(64, 64, kernel_size=3, padding=1),
  70. nn.BatchNorm2d(64),
  71. nn.ReLU(inplace=True),
  72. nn.MaxPool2d(kernel_size=2, stride=2) # 14 * 14 -》 7 * 7
  73. )
  74. # 分类层
  75. self.classifier = nn.Sequential(
  76. # Dropout层
  77. # p = 0.5 代表该层的每个权重有0.5的可能性为0
  78. nn.Dropout(p = 0.5),
  79. # 这里是通道数64 * 图像大小7 * 7,然后输入到512个神经元中
  80. nn.Linear(64 * 7 * 7, 512),
  81. nn.BatchNorm1d(512),
  82. nn.ReLU(inplace=True),
  83. nn.Dropout(p = 0.5),
  84. nn.Linear(512, 512),
  85. nn.BatchNorm1d(512),
  86. nn.ReLU(inplace=True),
  87. nn.Dropout(p = 0.5),
  88. nn.Linear(512, 10),
  89. )
  90.  
  91. def forward(self, x):
  92. # 经过特征提取层
  93. x = self.features(x)
  94. # 输出结果必须展平成一维向量
  95. x = x.view(x.size(0), -1)
  96. x = self.classifier(x)
  97. return x
  98.  
  99. # 初始化模型
  100. ConvModel = ConvNet().to(DEVICE)
  101. # 定义交叉熵损失函数
  102. criterion = nn.CrossEntropyLoss().to(DEVICE)
  103. # 定义模型优化器
  104. optimizer = torch.optim.Adam(ConvModel.parameters(), lr = learning_rate)
  105. # 定义学习率调度器
  106. exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)
  107.  
  108. def train(num_epochs,_model, _device, _train_loader, _optimizer, _lr_scheduler):
  109. _model.train()
  110. _lr_scheduler.step()
  111. for epoch in range(num_epochs):
  112. start = end = 0
  113. # 从迭代器抽取图片和标签
  114. for i, (images, labels) in enumerate(_train_loader):
  115. if (i + 1) % 100 == 1:
  116. start = timeit.default_timer()
  117. samples = images.to(_device)
  118. labels = labels.to(_device)
  119. #此时样本是一批图片,在CNN的输入中,我们需要将其变为四维,
  120. # reshape第一个-1 代表自动计算批量图片的数目n
  121. # 最后reshape得到的结果就是n张图片,每一张图片都是单通道的28 * 28,得到四维张量
  122. output = _model(samples.reshape(-1, 1, 28, 28))
  123.  
  124. # 计算损失函数值
  125. loss = criterion(output, labels)
  126.  
  127. # 优化器内部参数梯度必须变为0
  128. optimizer.zero_grad()
  129.  
  130. # 损失值后向传播
  131. loss.backward()
  132.  
  133. # 更新模型参数
  134. optimizer.step()
  135.  
  136. if (i + 1) % 100 == 0:
  137. end = timeit.default_timer()
  138. print("Epoch:{}/{}, Time:{}s, step:{}, loss:{:.4f}".format(epoch+1, num_epochs, end-start, i + 1, loss.item()))
  139.  
  140. def test(_test_loader, _model, _device):
  141. _model.eval() # 设置模型进入预测模式 evaluation
  142. loss = 0
  143. correct = 0
  144.  
  145. with torch.no_grad(): #如果不需要 backward更新梯度,那么就要禁用梯度计算,减少内存和计算资源浪费。
  146. for data, target in _test_loader:
  147. data, target = data.to(_device), target.to(_device)
  148. output = ConvModel(data.reshape(-1, 1, 28, 28))
  149. loss += criterion(output, target).item() # 添加损失值
  150. pred = output.data.max(1, keepdim=True)[1] # 找到概率最大的下标,为输出值
  151. correct += pred.eq(target.data.view_as(pred)).cpu().sum() # .cpu()是将参数迁移到cpu上来。
  152.  
  153. loss /= len(_test_loader.dataset)
  154.  
  155. print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
  156. loss, correct, len(_test_loader.dataset),
  157. 100. * correct / len(_test_loader.dataset)))
  158.  
  159. for epoch in range(1, EPOCHS + 1):
  160. train(epoch, ConvModel, DEVICE, train_loader, optimizer, exp_lr_scheduler)
  161. test(test_loader,ConvModel, DEVICE)
  162. test(train_loader,ConvModel, DEVICE)

一套兼容win和Linux的PyTorch训练MNIST的算法代码(CNN)的更多相关文章

  1. php中路径斜杠的应用,兼容win与linux

    更多内容推荐微信公众号,欢迎关注: PHP中斜杠的运用 兼容win和linux 使用常量:DIRECTORY_SEPARATOR如:"www".DIRECTORY_SEPARATO ...

  2. 跨平台设置NODE_ENV(兼容win和linux)

    通过NODE_ENV可以来设置环境变量(默认值为development).一般我们通过检查这个值来分别对开发环境和生产环境下做不同的处理.可以在命令行中通过下面的方式设置这个值: linux & ...

  3. 用Pytorch训练MNIST分类模型

    本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...

  4. Sublime Text 2 - 性感无比的代码编辑器!程序员必备神器!跨平台支持Win/Mac/Linux

    我用过的编辑器不少,真不少- 但却没有哪款让我特别心仪的,直到我遇到了 Sublime Text 2 !如果说“神器”是我能给予一款软件最高的评价,那么我很乐意为它封上这么一个称号.它小巧绿色且速度非 ...

  5. [转载]Sublime Text 2 - 性感无比的代码编辑器!程序员必备神器!跨平台支持Win/Mac/Linux

    代码编辑器或者文本编辑器,对于程序员来说,就像剑与战士一样,谁都想拥有一把可以随心驾驭且锋利无比的宝剑,而每一位程序员,同样会去追求最适合自己的强大.灵活的编辑器,相信你和我一样,都不会例外. 我用过 ...

  6. Java文件夹操作,判断多级路径是否存在,不存在就创建(包括windows和linux下的路径字符分析),兼容Windows和Linux

    兼容windows和linux. 分析: 在windows下路径有以下表示方式: (标准)D:\test\1.txt (不标准,参考linux)D:/test/1.txt 然后在java中,尤其使用F ...

  7. paip兼容windows与linux的java类根目录路径的方法

    paip兼容windows与linux的java类根目录路径的方法 1.只有 pathx.class.getResource("")或者pathx.class.getResourc ...

  8. redhat 安装配置samba实现win共享linux主机目录

    [转]http://blog.chinaunix.net/uid-26642180-id-3135941.html redhat 安装配置samba实现win共享linux主机目录 2012-03-1 ...

  9. Win和Linux查看端口和杀死进程

    title: Win和Linux查看端口和杀死进程 date: 2017-7-30 tags: null categories: Linux --- 本文介绍Windows和Linux下查看端口和杀死 ...

随机推荐

  1. MyBatis踩坑之SQLProvider转义字符被删除问题

    目录 踩坑背景 问题描述 原因追踪 解决方案 方法一 方法二 踩坑背景 项目架构:Spring Boot + MyBatis + MySQL. 使用MyBatis作为ORM框架,jdbc驱动使用的是m ...

  2. 「杂录」CSP-S 2019 爆炸记&题解

    考试状况 \(Day1\) \(8:30\) 解压,先打个含头文件和\(freopen\)的模板程序,准备做题. \(8:35\) 开题,心想着按顺序做吧,毕竟难度一般是按顺序排的. 第一题,一眼看过 ...

  3. 使用SSM搭建一个简单的crud项目

    使用SSM完成增删查改 前端使用到的技术:ajax,json,bootstrap等 完整项目地址:点这里GitHub 项目地址,可以在线访问 这一章节主要搭建SSM的环境. SpringMVC Spr ...

  4. [BZOJ4382][POI2015]Podział naszyjnika (神奇HASH)

    [问题描述]    长度为n 的一串项链,每颗珠子是K 种颜色之一.第i 颗与第i-1,i+1 颗珠子相邻,第n 颗与第1 颗也相邻.    切两刀,把项链断成两条链.要求每种颜色的珠子只能出现在其中 ...

  5. C++分治策略实现快速排序

    问题描述: 给定一个未知顺序的n个元素组成的数组,现要利用快速排序算法对这n个元素进行非递减排序. 细节须知: (1)代码实现了利用递归对数组进行快速排序,其中limit为从已有的随机数文件中输入的所 ...

  6. LeetCode 5198. 丑数 III(Java)容斥原理和二分查找

    题目链接:5198. 丑数 III 请你帮忙设计一个程序,用来找出第 n 个丑数. 丑数是可以被 a 或 b 或 c 整除的 正整数. 示例 1: 输入:n = 3, a = 2, b = 3, c ...

  7. JMX远程监控JVM

    远程监控JVM状态需要在JVM启动的时候需要加上一段代码开启这个功能.(以下全部以ubuntu-14-04-server.jdk1.8.tomcat7.0环境为基础) 配置的时候分两种情况:1.无需配 ...

  8. 【题解】Luogu P5400 [CTS2019]随机立方体

    原题传送门 毒瘤计数题 我们设\(dp_i\)表示至少有\(i\)个极大数字的概率,\(ans_i\)表示恰好有\(i\)个极大数的概率,\(mi=Min(n,m,l)\) 易知: \[dp_i=\s ...

  9. Harbor配置自签名证书,docker login+web https访问,helm chart推送应用

    注:高版本(14以上)docker执行login命令,默认使用https,且harbor必须使用域名,只是用ip访问是不行的. 假设使用的网址是:www.harbor.mobi,本机ip是192.16 ...

  10. SpringCloud整合sleuth,使用zipkin时不显示服务

    转载于:https://www.cnblogs.com/Dandwj/p/11179141.html 原文地址:https://blog.csdn.net/weixin_30416497/articl ...