1. prefetch_generator

使用 prefetch_generator库 在后台加载下一batch的数据,原本PyTorch默认的DataLoader会创建一些worker线程来预读取新的数据,但是除非这些线程的数据全部都被清空,这些线程才会读下一批数据。使用prefetch_generator,我们可以保证线程不会等待,每个线程都总有至少一个数据在加载。

  • 安装

    pip install prefetch_generator
  • 使用

    之前加载数据集的正确方式是使用torch.utils.data.DataLoader,现在我们只要利用这个库,新建个DataLoaderX类继承DataLoader并重写__iter__方法即可

    from torch.utils.data import DataLoader
    from prefetch_generator import BackgroundGenerator class DataLoaderX(DataLoader): def __iter__(self):
    return BackgroundGenerator(super().__iter__())

    之后这样用:

    train_dataset = MyDataset(".........")
    train_loader = DataLoaderX(dataset=train_dataset,
    batch_size=batch_size, num_workers=4, shuffle=shuffle)

2. Apex

2.1 安装

  1. 克隆源代码
git clone https://github.com/NVIDIA/apex

可以先下载到码云,再下载到本地

  1. 安装apex
cd apex
python setup.py install

最好打开PyCharm的终端进行安装,这样实在Anaconda的环境里安装了

  1. 删除刚刚clone下来的apex文件夹,然后重启PyCharm

【注意】安装PyTorch和cuda时注意版本对应,要按照正确流程安装

  1. 测试安装成功
from apex import amp

如果导入不报错说明安装成功

2.2 使用

from apex import amp  # 这个必须的,其他的导包省略了

train_dataset = MyDataset("......")
train_loader = DataLoader(dataset=train_dataset, batch_size=2, num_workers=4, shuffle=True) model = MyNet().to(device) # 创建模型 criterion = nn.MSELoss() # 定义损失函数 optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.00001) # 优化器 net, optimizer = amp.initialize(net, optimizer, opt_level="O1") # 这一步很重要 # 学习率衰减
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer=optimizer, mode="min",factor=0.1, patience=3,
verbose=False,cooldown=0, min_lr=0.0, eps=1e-7) for epoch in range(epochs):
net.train() # 训练模式 train_loss_epoch = [] # 记录一个epoch内的训练集每个batch的loss
test_loss_epoch = [] # 记录一个epoch内测试集的每个batch的loss for i, data in enumerate(train_loader):
# forward
x, y = data
x = x.to(device)
y = y.to(device) outputs = net(x) # backward
optimizer.zero_grad() loss = criterion(outputs, labels) # 这一步也很重要
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward() # 更新权重
optimizer.step() scheduler.step(1) # 更新学习率。每1步更新一次
  • 主要是添加了三行代码
  • scaled_loss 是将原loss放大了,所以要保存loss应该保存之前的值,这种放大防止梯度消失

考察amp.initialize(net, optimizer, opt_level="O1")opt_level参数

  • opt_level=O0(base)

    表示的是当前执行FP32训练,即正常的训练

  • opt_level=O1(推荐)

    表示的是当前使用部分FP16混合训练

  • opt_level=O2

    表示的是除了BN层的权重外,其他层的权重都使用FP16执行训练

  • opt_level=O3

    表示的是默认所有的层都使用FP16执行计算,当keep_batch norm_fp32=True,则会使用cudnn执行BN层的计算,该优化等级能够获得最快的速度,但是精度可能会有一些较大的损失

一般我们用O1级别就行,最多O2,注意,是不是

『PyTorch』屌丝的PyTorch玩法的更多相关文章

  1. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  2. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  3. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  4. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  5. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

  6. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  7. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  8. 『Python』__getattr__()特殊方法

    self的认识 & __getattr__()特殊方法 将字典调用方式改为通过属性查询的一个小class, class Dict(dict): def __init__(self, **kw) ...

  9. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

随机推荐

  1. noip15

    童话故事专场 T1 首先,dead line 是一条直线,而不是线段.考试的时候一直以为是线段,那么横竖共有n+m条,考虑斜着的,斜着的交点为有穷的,则需要满足斜率不同,那么只需要统计一边的,再乘2就 ...

  2. Java:学习什么是多线程

    线程是什么 进程是对CPU的抽象,而线程更细化了进程的运行流程 先看一下这个图 线程和进程的关系有 进程中就是线程在执行,所有(主)线程执行完了进程也就结束了 多个线程从1秒钟是同时运行完成,从1纳秒 ...

  3. wpf 富文本编辑器richtextbox的简单用法

    最近弄得一个小软件,需要用到富文本编辑器,richtextbox,一开始以为是和文本框一样的用法,但是实践起来碰壁之后才知道并不简单. richtextbox 类似于Word,是一个可编辑的控件.结构 ...

  4. 【mysql】mysql逻辑框架简介及show profile说明

    1.mysql逻辑框架简介 和其它数据库相比,MySQL 有点与众不同,它的架构可以在多种不同场景中应用并发挥良好作用.主要体现在存储引擎的架构上,插件式的存储引擎架构将查询处理和其它的系统任务以及数 ...

  5. mfc HackerTools拖动文件

    VOID DragAcceptFiles(          HWND hWnd,    BOOL fAccept); 这个函数的调用,表示你要让某个窗体能够接受文件的拖入.第一个参数指定是哪个窗口, ...

  6. 初探Spring Security

    Spring Security 简介 Spring Security是Spring家族中的一个组成框架,具有强大且高度可定制的身份验证和访问控制功能,致力于为Java应用程序提供身份的验证和授权 (先 ...

  7. canvas二次贝塞尔&三次贝塞尔操作实例

    Canvas Quadratic Curve Example canvas = document.getElementById("canvas"); ctx = canvas.ge ...

  8. Linux系统的日志管理、时间同步、延迟命令at

    方便查看和管理 /var/log/messages ?系统服务及日志,包括服务的信息,报错等等 /var/log/secure ? ? ? ? 系统认证信息日志 /var/log/maillog ? ...

  9. VS Code闪现,巨头纷纷入局的Web IDE缘何崛起?

    我发了,我装的. 就在前几天,微软简短的发布了Visual Studio Code for the Web 的公告,而没过一阵,这则公告就被删除了,现在点经相关内容已经是404状态了.虽然公告的内容已 ...

  10. Tars | 第3篇 Tars中期汇报测试文档(Java语言实现Subset路由规则)

    目录 前言 1. 任务介绍 2. 测试模拟方案 2.0 *前置工作 2.1 添加路由规则 2.2 添加存活节点 2.3 [输出]遍历输出当前存活节点 2.4 [核心]对存活节点按subset规则过滤 ...