[Pytorch框架] 1.6 训练一个分类器
训练一个分类器
上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重。
你现在可能在想下一步。
关于数据?
一般情况下处理图像、文本、音频和视频数据时,可以使用标准的Python包来加载数据到一个numpy数组中。
然后把这个数组转换成 torch.*Tensor
。
- 图像可以使用 Pillow, OpenCV
- 音频可以使用 scipy, librosa
- 文本可以使用原始Python和Cython来加载,或者使用 NLTK或
SpaCy 处理
特别的,对于图像任务,我们创建了一个包
torchvision
,它包含了处理一些基本图像数据集的方法。这些数据集包括
Imagenet, CIFAR10, MNIST 等。除了数据加载以外,torchvision
还包含了图像转换器,
torchvision.datasets
和 torch.utils.data.DataLoader
。
torchvision
包不仅提供了巨大的便利,也避免了代码的重复。
在这个教程中,我们使用CIFAR10数据集,它有如下10个类别
:‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,
‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10的图像都是
3x32x32大小的,即,3颜色通道,32x32像素。
训练一个图像分类器
依次按照下列顺序进行:
使用
torchvision
加载和归一化CIFAR10训练集和测试集定义一个卷积神经网络
定义损失函数
在训练集上训练网络
在测试集上测试网络
读取和归一化 CIFAR10
使用torchvision
可以非常容易地加载CIFAR10。
import torch
import torchvision
import torchvision.transforms as transforms
torchvision的输出是[0,1]的PILImage图像,我们把它转换为归一化范围为[-1, 1]的张量。
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz
100%|███████████████████████████████████████████████████████████████████████████████▉| 170M/170M [20:39<00:00, 155kB/s]
Files already downloaded and verified
我们展示一些训练图像。
import matplotlib.pyplot as plt
import numpy as np
# 展示图像的函数
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 获取随机数据
dataiter = iter(trainloader)
images, labels = dataiter.next()
# 展示图像
imshow(torchvision.utils.make_grid(images))
# 显示图像标签
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
171MB [20:51, 155kB/s]
cat car cat ship
- 定义一个卷积神经网络
从之前的神经网络一节复制神经网络代码,并修改为输入3通道图像。
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
- 定义损失函数和优化器
我们使用交叉熵作为损失函数,使用带动量的随机梯度下降。
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
- 训练网路
有趣的时刻开始了。
我们只需在数据迭代器上循环,将数据输入给网络,并优化。
for epoch in range(2): # 多批次循环
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# 获取输入
inputs, labels = data
# 梯度置0
optimizer.zero_grad()
# 正向传播,反向传播,优化
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 打印状态信息
running_loss += loss.item()
if i % 2000 == 1999: # 每2000批次打印一次
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
- 在测试集上测试网络
我们在整个训练集上进行了2次训练,但是我们需要检查网络是否从数据集中学习到有用的东西。
通过预测神经网络输出的类别标签与实际情况标签进行对比来进行检测。
如果预测正确,我们把该样本添加到正确预测列表。
第一步,显示测试集中的图片并熟悉图片内容。
dataiter = iter(testloader)
images, labels = dataiter.next()
# 显示图片
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
GroundTruth: cat ship ship plane
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7N9Ui0NR-1612164760383)(output_14_1.png)]
让我们看看神经网络认为以上图片是什么。
outputs = net(images)
输出是10个标签的能量。
一个类别的能量越大,神经网络越认为它是这个类别。所以让我们得到最高能量的标签。
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
for j in range(4)))
Predicted: plane plane plane plane
结果看来不错。
接下来让看看网络在整个测试集上的结果如何。
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
Accuracy of the network on the 10000 test images: 9 %
结果看起来不错,至少比随机选择要好,随机选择的正确率为10%。
似乎网络学习到了一些东西。
在识别哪一个类的时候好,哪一个不好呢?
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 99 %
Accuracy of car : 0 %
Accuracy of bird : 0 %
Accuracy of cat : 0 %
Accuracy of deer : 0 %
Accuracy of dog : 0 %
Accuracy of frog : 0 %
Accuracy of horse : 0 %
Accuracy of ship : 0 %
Accuracy of truck : 0 %
下一步?
我们如何在GPU上运行神经网络呢?
在GPU上训练
把一个神经网络移动到GPU上训练就像把一个Tensor转换GPU上一样简单。并且这个操作会递归遍历有所模块,并将其参数和缓冲区转换为CUDA张量。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 确认我们的电脑支持CUDA,然后显示CUDA信息:
print(device)
本节的其余部分假定device
是CUDA设备。
然后这些方法将递归遍历所有模块并将模块的参数和缓冲区
转换成CUDA张量:
net.to(device)
记住:inputs, targets 和 images 也要转换。
inputs, labels = inputs.to(device), labels.to(device)
为什么我们没注意到GPU的速度提升很多?那是因为网络非常的小。
实践:
尝试增加你的网络的宽度(第一个nn.Conv2d
的第2个参数,第二个nn.Conv2d
的第一个参数,它们需要是相同的数字),看看你得到了什么样的加速。
实现的目标:
- 深入了解了PyTorch的张量库和神经网络
- 训练了一个小网络来分类图片
译者注:后面我们教程会训练一个真正的网络,使识别率达到90%以上。
多GPU训练
如果你想使用所有的GPU得到更大的加速,
请查看数据并行处理。
下一步?
- :doc:
训练神经网络玩电子游戏 </intermediate/reinforcement_q_learning>
在ImageNet上训练最好的ResNet
使用对抗生成网络来训练一个人脸生成器
使用LSTM网络训练一个字符级的语言模型
更多示例
更多教程
在论坛上讨论PyTorch
Slack上与其他用户讨论
[Pytorch框架] 1.6 训练一个分类器的更多相关文章
- PyTorch Tutorials 4 训练一个分类器
%matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...
- 【PyTorch深度学习60分钟快速入门 】Part4:训练一个分类器
太棒啦!到目前为止,你已经了解了如何定义神经网络.计算损失,以及更新网络权重.不过,现在你可能会思考以下几个方面: 0x01 数据集 通常,当你需要处理图像.文本.音频或视频数据时,你可以使用标准 ...
- 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()
模型训练的三要素:数据处理.损失函数.优化算法 数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...
- 使用weka训练一个分类器
1 训练集数据 1.1 csv格式 5.1,3.5,1.4,0.2,Iris-setosa 4.9,3.0,1.4,0.2,Iris-setosa 4.7,3.2,1.3,0.2,Iris-setos ...
- Fine-tuning Convolutional Neural Networks for Biomedical Image Analysis: Actively and Incrementally如何使用尽可能少的标注数据来训练一个效果有潜力的分类器
作者:AI研习社链接:https://www.zhihu.com/question/57523080/answer/236301363来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商业转载 ...
- 迁移学习算法之TrAdaBoost ——本质上是在用不同分布的训练数据,训练出一个分类器
迁移学习算法之TrAdaBoost from: https://blog.csdn.net/Augster/article/details/53039489 TradaBoost算法由来已久,具体算法 ...
- 如何使用Pytorch迅速实现Mnist数据及分类器
一段时间没有更新博文,想着也该写两篇文章玩玩了.而从一个简单的例子作为开端是一个比较不错的选择.本文章会手把手地教读者构建一个简单的Mnist(Fashion-Mnist同理)的分类器,并且会使用相对 ...
- 【chainer框架】【pytorch框架】
教程: https://bennix.github.io/ https://bennix.github.io/blog/2017/12/14/chain_basic/ https://bennix.g ...
- pytorch:EDSR 生成训练数据的方法
Pytorch:EDSR 生成训练数据的方法 引言 Winter is coming 正文 pytorch提供的DataLoader 是用来包装你的数据的工具. 所以你要将自己的 (numpy arr ...
- 手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...
随机推荐
- Class 'dmstr\web\AdminLteAsset' not found
Yii2出现 Class 'dmstr\web\AdminLteAsset' not found 报错 1.检查下是不是vendor从其他地方复制过来的 2.检查根目录composer.json 中 ...
- MySql 入门——日期计算
MySQL自带的日期函数TIMESTAMPDIFF计算两个日期相差的秒数.分钟数.小时数.天数.周数.季度数.月数.年数,当前日期增加或者减少一天.一周等等 SELECT TIMESTAMPDIFF( ...
- Vue的官方脚手架 Vue-cli 安装使用解析
------------恢复内容开始------------ 1.首先什么是vue-cli 可以知道Vue-cli是一个官方提供的脚手架,主要作用是用来快速搭建Vue的项目模板,可以预先定义好项目的结 ...
- 还不知道如何在java中终止一个线程?快来,一文给你揭秘
目录 简介 Thread.stop被禁用之谜 怎么才能安全? 捕获异常之后的处理 总结 简介 工作中我们经常会用到线程,一般情况下我们让线程执行就完事了,那么你们有没有想过如何去终止一个正在运行的线程 ...
- subline Text 设置中文
subline Text是一个轻量级的文本编辑器,类似于记事本,不过它拥有代码高亮,简约好看的主题. 下载地址:https://download.sublimetext.com/sublime_tex ...
- Spring--AOP切入点表达式
AOP工作流程 能够与做代理的那个类匹配得上的话,叫做代理对象,否则为原始对象. (SpringAOP的本质:代理模式) AOP的切入点表达式 切入点表达式描述的标准格式 描述方式一:定位到某某包下的 ...
- 关于如何编写好金融科技客户端SDK的思考
引言 回想起来,我在目前的团队(金融科技领域)待了有很长一段时间了,一直在做SDK研发,平时工作中经历过大刀阔斧一蹴而就的喜悦,也经历过被一个问题按在地上摩擦,无奈"废寝忘食"的不 ...
- uniapp微信小程序解析详情页的四种方法
一.用微信文档提供的RICH-TEXT 官方文档:微信文档rich-text 这种是直接使用: <!-->content是API获取的html代码</--> <rich- ...
- AOP的九点核心概念和作用
AOP为Aspect Oriented Programming的缩写,意为:面向切面编程,通过预编译方式和运行期间动态代理实现程序功能的统一维护的一种技术.AOP是OOP的延续,是软件开发中的一个热点 ...
- 磁盘IO 基本常识
计算机硬件性能在过去十年间的发展普遍遵循摩尔定律,通用计算机的 CPU主频早已超过3GHz,内存也进入了普及DDR4的时代.然而传统硬盘虽然在存储容量上增长迅速,但是在读写性能上并无明显提升,同时SS ...