[Pytorch]基于混和精度的模型加速
这篇博客是在pytorch中基于apex使用混合精度加速的一个偏工程的描述,原理层面的解释并不是这篇博客的目的,不过在参考部分提供了非常有价值的资料,可以进一步研究。
一个关键原则:“仅仅在权重更新的时候使用fp32,耗时的前向和后向运算都使用fp16”。其中的一个技巧是:在反向计算开始前,将dloss乘上一个scale,人为变大;权重更新前,除去scale,恢复正常值。目的是为了减小激活gradient下溢出的风险。
apex是nvidia的一个pytorch扩展,用于支持混合精度训练和分布式训练。在之前的博客中,神经网络的Low-Memory技术梳理了一些low-memory技术,其中提到半精度,比如fp16。apex中混合精度训练可以通过简单的方式开启自动化实现,组里同学交流的结果是:一般情况下,自动混合精度训练的效果不如手动修改。分布式训练中,有社区同学心心念念的syncbn的支持。关于syncbn,在去年做CV的时候,我们就有一些来自民间的尝试,不过具体提升还是要考虑具体任务场景。
那么问题来了,如何在pytorch中使用fp16混合精度训练呢?
第零:混合精度训练相关的参数
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
第一:模型参数转换为fp16
nn.Module中的half()方法将模型中的float32转化为float16,实现的原理是遍历所有tensor,而float32和float16都是tensor的属性。也就是说,一行代码解决,如下:
model.half()
第二:修改优化器
在pytorch下,当使用fp16时,需要修改optimizer。类似代码如下(代码参考这里):
# Prepare optimizer
if args.do_train:
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)
第三:backward时做对应修改
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
第四:学习率修改
if args.fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
根据参考3,值得重述一些重要结论:
(1)深度学习训练使用16bit表示/运算正逐渐成为主流。
(2)低精度带来了性能、功耗优势,但需要解决量化误差(溢出、舍入)。
(3)常见的避免量化误差的方法:为权重保持高精度(fp32)备份;损失放大,避免梯度的下溢出;一些特殊层(如BatchNorm)仍使用fp32运算。
参考资料:
1.nv官方repo给了一些基于pytorch的apex加速的实现
实现是基于fairseq实现的,可以直接对比代码1-apex版和代码2-非apex版(fairseq官方版),了解是如何基于apex实现加速的。
按图索骥,可以get到很多更加具体地内容。
感谢团队同学推荐。
[Pytorch]基于混和精度的模型加速的更多相关文章
- 【神经网络篇】--基于数据集cifa10的经典模型实例
一.前述 本文分享一篇基于数据集cifa10的经典模型架构和代码. 二.代码 import tensorflow as tf import numpy as np import math import ...
- 基于MATLAB搭建的DDS模型
基于MATLAB搭建的DDS模型 说明: 累加器输出ufix_16_6数据,通过cast切除小数部分,在累加的过程中,带小数进行运算最后对结果进行处理,这样提高了计算精度. 关于ROM的使用: 直接设 ...
- StartDT AI Lab | 视觉智能引擎之算法模型加速
通过StartDT AI Lab专栏之前多篇文章叙述,相信大家已经对计算机视觉技术及人工智能算法在奇点云AIOT战略中的支撑作用有了很好的理解.同样,这种业务牵引,技术覆盖的模式也收获了市场的良好反响 ...
- Atitit 基于meta的orm,提升加速数据库相关应用的开发
Atitit 基于meta的orm,提升加速数据库相关应用的开发 1.1. Overview概论1 1.2. Function & Feature功能特性1 1.2.1. meta api2 ...
- 基于git的源代码管理模型——git flow
基于git的源代码管理模型--git flow A successful Git branching model
- 详解Linux2.6内核中基于platform机制的驱动模型 (经典)
[摘要]本文以Linux 2.6.25 内核为例,分析了基于platform总线的驱动模型.首先介绍了Platform总线的基本概念,接着介绍了platform device和platform dri ...
- 基于R语言的ARIMA模型
A IMA模型是一种著名的时间序列预测方法,主要是指将非平稳时间序列转化为平稳时间序列,然后将因变量仅对它的滞后值以及随机误差项的现值和滞后值进行回归所建立的模型.ARIMA模型根据原序列是否平稳以及 ...
- 基于PaddlePaddle的语义匹配模型DAM,让聊天机器人实现完美回复 |
来源商业新知网,原标题:让聊天机器人完美回复 | 基于PaddlePaddle的语义匹配模型DAM 语义匹配 语义匹配是NLP的一项重要应用.无论是问答系统.对话系统还是智能客服,都可以认为是问题和回 ...
- 第13章 TCP编程(4)_基于自定义协议的多线程模型
7. 基于自定义协议的多线程模型 (1)服务端编程 ①主线程负责调用accept与客户端连接 ②当接受客户端连接后,创建子线程来服务客户端,以处理多客户端的并发访问. ③服务端接到的客户端信息后,回显 ...
随机推荐
- mysql官网下载驱动包
[转载]原文链接:http://blog.csdn.net/u010523770/article/details/52240946 驱动官网下载地址:http://dev.mysql.com/down ...
- 【JZOJ3216】【SDOI2013】淘金
╰( ̄▽ ̄)╭ 小 Z在玩一个 叫做<淘金者>的游戏.游戏的世界是一个 二维坐标 .X轴.Y轴坐标范围均为1..N.初始的时候,所有的整数坐标点上均有一块金子,共 N*N 块. 一阵风吹过 ...
- MySQL——自定义[存储]函数、触发器
一. 编程基础 1) 结束符 2) 代码块 Begin 相当于 { end; 相当于 } 1. 变量 系统变量 Show variables; 查看系统变量sql_ ...
- Mathematica 和 MATLAB、Maple 并称为三大数学软件
Mathematica是一款科学计算软件,很好地结合了数值和符号计算引擎.图形系统.编程语言.文本系统.和与其他应用程序的高级连接.很多功能在相应领域内处于世界领先地位,它也是使用最广泛的数学软件之一 ...
- 【JZOJ4763】【NOIP2016提高A组模拟9.7】旷野大计算
题目描述 输入 输出 样例输入 5 5 9 8 7 8 9 1 2 3 4 4 4 1 4 2 4 样例输出 9 8 8 16 16 数据范围 解法 离线莫队做法 考虑使用莫队,但由于在删数的时候难以 ...
- Directx教程(27) 简单的光照模型(6)
原文:Directx教程(27) 简单的光照模型(6) 从myTutorialD3D11_15到myTutorialD3D11_19的工程中,我们都只有一个光源,光源的位置在LightCla ...
- 【记录Bug】 This is probably not a problem with npm. There is likely additional logging output above.
一个eslint的错误 我的报错如下 $ npm install > node-sass@4.11.0 install C:\Users\Administrator\Desktop\forGit ...
- 外贸电子商务网站之Prestashop 语言包安装
prestashop添加语言-下载语言包 我们找到中文简体(Chinese-Simplified)一行,点击最后一栏的下载(Download)按钮,我们点击下载,可以下到一个以语言的 ISO为文件名, ...
- bzoj1800 飞行棋
脑筋急转弯. 提示:矩形矩形矩形.O(n)O(n)O(n). 再提示:直角. 再提示:直径. 代码: //Serene #include<algorithm> #include<io ...
- 使用php simple html dom parser解析html标签
转自:http://www.blhere.com/1243.html 使用php simple html dom parser解析html标签 用了一下 PHP Simple HTML DOM Par ...