这篇博客是在pytorch中基于apex使用混合精度加速的一个偏工程的描述,原理层面的解释并不是这篇博客的目的,不过在参考部分提供了非常有价值的资料,可以进一步研究。

一个关键原则:“仅仅在权重更新的时候使用fp32,耗时的前向和后向运算都使用fp16”。其中的一个技巧是:在反向计算开始前,将dloss乘上一个scale,人为变大;权重更新前,除去scale,恢复正常值。目的是为了减小激活gradient下溢出的风险。

apex是nvidia的一个pytorch扩展,用于支持混合精度训练和分布式训练。在之前的博客中,神经网络的Low-Memory技术梳理了一些low-memory技术,其中提到半精度,比如fp16。apex中混合精度训练可以通过简单的方式开启自动化实现,组里同学交流的结果是:一般情况下,自动混合精度训练的效果不如手动修改。分布式训练中,有社区同学心心念念的syncbn的支持。关于syncbn,在去年做CV的时候,我们就有一些来自民间的尝试,不过具体提升还是要考虑具体任务场景。

那么问题来了,如何在pytorch中使用fp16混合精度训练呢?

第零:混合精度训练相关的参数

  1. parser.add_argument('--fp16',
  2. action='store_true',
  3. help="Whether to use 16-bit float precision instead of 32-bit")
  4. parser.add_argument('--loss_scale',
  5. type=float, default=0,
  6. help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
  7. "0 (default value): dynamic loss scaling.\n"
  8. "Positive power of 2: static loss scaling value.\n")

第一:模型参数转换为fp16

nn.Module中的half()方法将模型中的float32转化为float16,实现的原理是遍历所有tensor,而float32和float16都是tensor的属性。也就是说,一行代码解决,如下:

  1. model.half()

第二:修改优化器

在pytorch下,当使用fp16时,需要修改optimizer。类似代码如下(代码参考这里):

  1. # Prepare optimizer
  2. if args.do_train:
  3. param_optimizer = list(model.named_parameters())
  4. no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
  5. optimizer_grouped_parameters = [
  6. {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
  7. {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  8. ]
  9. if args.fp16:
  10. try:
  11. from apex.optimizers import FP16_Optimizer
  12. from apex.optimizers import FusedAdam
  13. except ImportError:
  14. raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
  15. optimizer = FusedAdam(optimizer_grouped_parameters,
  16. lr=args.learning_rate,
  17. bias_correction=False,
  18. max_grad_norm=1.0)
  19. if args.loss_scale == 0:
  20. optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
  21. else:
  22. optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
  23. warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
  24. t_total=num_train_optimization_steps)
  25. else:
  26. optimizer = BertAdam(optimizer_grouped_parameters,
  27. lr=args.learning_rate,
  28. warmup=args.warmup_proportion,
  29. t_total=num_train_optimization_steps)

第三:backward时做对应修改

  1. if args.fp16:
  2. optimizer.backward(loss)
  3. else:
  4. loss.backward()

第四:学习率修改

  1. if args.fp16:
  2. # modify learning rate with special warm up BERT uses
  3. # if args.fp16 is False, BertAdam is used that handles this automatically
  4. lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
  5. for param_group in optimizer.param_groups:
  6. param_group['lr'] = lr_this_step
  7. optimizer.step()
  8. optimizer.zero_grad()

根据参考3,值得重述一些重要结论:

(1)深度学习训练使用16bit表示/运算正逐渐成为主流。

(2)低精度带来了性能、功耗优势,但需要解决量化误差(溢出、舍入)。

(3)常见的避免量化误差的方法:为权重保持高精度(fp32)备份;损失放大,避免梯度的下溢出;一些特殊层(如BatchNorm)仍使用fp32运算。

参考资料:

1.nv官方repo给了一些基于pytorch的apex加速的实现

实现是基于fairseq实现的,可以直接对比代码1-apex版代码2-非apex版(fairseq官方版),了解是如何基于apex实现加速的。

2.nv官方关于混合精度优化的原理介绍

按图索骥,可以get到很多更加具体地内容。

3.低精度表示用于深度学习 训练与推断

感谢团队同学推荐。

[Pytorch]基于混和精度的模型加速的更多相关文章

  1. 【神经网络篇】--基于数据集cifa10的经典模型实例

    一.前述 本文分享一篇基于数据集cifa10的经典模型架构和代码. 二.代码 import tensorflow as tf import numpy as np import math import ...

  2. 基于MATLAB搭建的DDS模型

    基于MATLAB搭建的DDS模型 说明: 累加器输出ufix_16_6数据,通过cast切除小数部分,在累加的过程中,带小数进行运算最后对结果进行处理,这样提高了计算精度. 关于ROM的使用: 直接设 ...

  3. StartDT AI Lab | 视觉智能引擎之算法模型加速

    通过StartDT AI Lab专栏之前多篇文章叙述,相信大家已经对计算机视觉技术及人工智能算法在奇点云AIOT战略中的支撑作用有了很好的理解.同样,这种业务牵引,技术覆盖的模式也收获了市场的良好反响 ...

  4. Atitit  基于meta的orm,提升加速数据库相关应用的开发

    Atitit  基于meta的orm,提升加速数据库相关应用的开发 1.1. Overview概论1 1.2. Function & Feature功能特性1 1.2.1. meta api2 ...

  5. 基于git的源代码管理模型——git flow

    基于git的源代码管理模型--git flow A successful Git branching model

  6. 详解Linux2.6内核中基于platform机制的驱动模型 (经典)

    [摘要]本文以Linux 2.6.25 内核为例,分析了基于platform总线的驱动模型.首先介绍了Platform总线的基本概念,接着介绍了platform device和platform dri ...

  7. 基于R语言的ARIMA模型

    A IMA模型是一种著名的时间序列预测方法,主要是指将非平稳时间序列转化为平稳时间序列,然后将因变量仅对它的滞后值以及随机误差项的现值和滞后值进行回归所建立的模型.ARIMA模型根据原序列是否平稳以及 ...

  8. 基于PaddlePaddle的语义匹配模型DAM,让聊天机器人实现完美回复 |

    来源商业新知网,原标题:让聊天机器人完美回复 | 基于PaddlePaddle的语义匹配模型DAM 语义匹配 语义匹配是NLP的一项重要应用.无论是问答系统.对话系统还是智能客服,都可以认为是问题和回 ...

  9. 第13章 TCP编程(4)_基于自定义协议的多线程模型

    7. 基于自定义协议的多线程模型 (1)服务端编程 ①主线程负责调用accept与客户端连接 ②当接受客户端连接后,创建子线程来服务客户端,以处理多客户端的并发访问. ③服务端接到的客户端信息后,回显 ...

随机推荐

  1. windows 操作系统下git报filename too long 处理方法

    两种方法解决: 一是通过修改配置文件 [core] repositoryformatversion = filemode = false bare = false logallrefupdates = ...

  2. PAI-STUDIO通过Tensorflow处理MaxCompute表数据

    PAI-STUDIO在支持OSS数据源的基础上,增加了对MaxCompute表的数据支持.用户可以直接使用PAI-STUDIO的Tensorflow组件读写MaxCompute数据,本教程将提供完整数 ...

  3. 使用新版本5+SDK创建最简Android原生工程(Android studio)http://ask.dcloud.net.cn/article/13232

    1 使用Android Studio创建一个工程 2 删除原生工程中Java目录下系统默认创建的源代码 3 复制SDK->libs->lib.5plus.base-release.aar文 ...

  4. Android消息机制使用注意事项,防止泄漏

    在Android的线程通信当中,使用频率最多的就是Android的消息处理机制(Handler.send().View.post().Asynctask.excute()等等都使用到了消息处理机制). ...

  5. 注解1 --- JDK内置的三个基本注解 --- 技术搬运工(尚硅谷)

    @Override: 限定重写父类方法, 该注解只能用于方法 @Deprecated: 用于表示所修饰的元素(类, 方法等)已过时.通常是因为所修饰的结构危险或存在更好的选择 @SuppressWar ...

  6. python之高阶函数--map()和reduce()

    以下为学习笔记:来自廖雪峰的官方网站 1.高阶函数:简单来说是一个函数里面嵌入另一个函数 2.python内建的了map()和reduce()函数 map()函数接收两参数,一个是函数,一个是Iter ...

  7. jQuery中的工具和插件

    jQuery的工具属性 jQuery类数组操作 length属性 表示获取类数组中元素的个数 get()方法 表示获取类数组中单个元素"括号中填写该元素的索引值" index()方 ...

  8. BZOJ 4420二重镇题解

    链接 思路借鉴了这个博客: 我们可以想到状压dp 用一个十进制数来表示状态,即第i位表示位置i处的物品等级 用f[i][j][k]表示第i天,仓库的物品等级为j,状态为k时的最大收益 但是状态数貌似很 ...

  9. qt 鼠标拖动窗口放大缩小

    // 鼠标拖动 具体实现void mouseMoveEvent(QMouseEvent * pEvent) { if (pEvent->buttons() & Qt::LeftButto ...

  10. input的相关兼容性问题

    近来在制作登陆页的input文本框和密码框的时候,具体的实例可参考实现带样式的表单验证,我们发现在IE下默认的情况下,input 标签的密码框和文本框宽度不一致,这就尴尬了. 解决这个办法,我们是直接 ...