Megengine量化

量化指的是将浮点数模型(一般是32位浮点数)的权重或激活值用位数更少的数值类型(比如8位整数、16位浮点数)来近似表示的过程。 量化后的模型会占用更小的存储空间,还能够利用许多硬件平台上的专属算子进行提速。比如在 MegEngine 中使用8位整数来进行量化,相比默认的32位浮点数,模型大小可以减少为1/4,而运行在特定的设备上其计算速度也能提升为2-4倍。

量化的目的是为了追求极致的推理计算速度,为此舍弃了数值表示的精度,直觉上会带来较大的模型掉点,但是在使用一系列精细的量化处理之后,其掉点可以变得微乎其微,并能支持正常的部署使用。而且近年来随着专用神经网络加速芯片的兴起,低比特非浮点的运算方式越来越普及,因此如何把一个 GPU 上训练的浮点数模型转化为低比特的量化模型,就成为了工业界非常关心的话题。

一般来说,得到量化模型的转换过程按代价从低到高可以分为以下4种:

图1. 量化转换过程分类

  • Type1 和 Type2 由于是在模型浮点模型训练之后介入,无需大量训练数据,故而转换代价更低,被称为后量化(Post Quantization);
  • Type3 和 Type4 则需要在浮点模型训练时就插入一些假量化(FakeQuantize)算子,模拟计算过程中数值截断后精度降低的情形,故而称为量化感知训练(Quantization Aware Training, QAT)。

本文主要介绍 Type2 和 Type3 在 MegEngine 中的完整流程,事实上,除了 Type2 无需进行假量化,两者的整体流程完全一致。

整体流程

以 Type3 为例,一般以一个训练完毕的浮点模型为起点,称为 Float 模型。包含假量化算子的用浮点操作来模拟量化过程的新模型,称为 Quantized-Float 模型,或者 QFloat 模型。可以直接在终端设备上运行的模型,称为 Quantized 模型,简称 Q 模型。

而三者的精度一般是 Float > QFloat > Q ,故而一般量化算法也就分为两步:

  • 拉近 QFloat 和 Q,这样训练阶段的精度可以作为最终 Q 精度的代理指标,这一阶段偏工程;
  • 拔高 QFloat 逼近 Float,这样就可以将量化模型性能尽可能恢复到 Float 的精度,这一阶段偏算法。

典型的三种模型在三个阶段的精度变化如下:

图2. 三阶段模型的精度变化

对应到具体的 MegEngine 接口中,三阶段如下:

  1. 基于 Module 搭建网络模型,并按照正常的浮点模型方式进行训练;
  2. 使用 quantize_qat() 将浮点模型转换为 QFloat 模型,其中可被量化的关键 Module 会被转换为 QATModule ,并基于量化配置 QConfig 设置好假量化算子和数值统计方式;
  3. 使用 quantize() 将 QFloat 模型转换为 Q 模型,对应的 QATModule 则会被转换为 QuantizedModule ,此时网络无法再进行训练,网络中的算子都会转换为低比特计算方式,即可用于部署了。

该流程是 Type3 对应 QAT 的步骤,Type2 对应的后量化则需使用不同 QConfig,且需使用 evaluation 模式运行 QFloat 模型,而非训练模式。更多细节可以继续阅读下一节详细的接口介绍。

接口介绍

在 MegEngine 中,最上层的接口是配置如何量化的 QConfig 和模型转换模块里的 quantize_qat() 与 quantize() 。

QConfig

QConfig 包括了 Observer 和 FakeQuantize 两部分。知道,对模型转换为低比特量化模型一般分为两步:一是统计待量化模型中参数和 activation 的数值范围(scale)和零点(zero_point),二是根据 scale 和 zero_point 将模型转换成指定的数值类型。而为了统计这两个值,需要使用 Observer。

Observer 继承自 Module ,也会参与网络的前向传播,但是其 forward 的返回值就是输入,所以不会影响网络的反向梯度传播。其作用就是在前向时拿到输入的值,并统计其数值范围,并通过 get_qparams() 来获取。所以在搭建网络时把需要统计数值范围的的 Tensor 作为 Observer 的输入即可。

# forward of MinMaxObserver

def forward(self, x_orig):

if self.enabled:

# stop gradient

x = x_orig.detach()

# find max and min

self.min_val._reset(F.minimum(self.min_val, x.min()))

self.max_val._reset(F.maximum(self.max_val, x.max()))

return x_orig

如果只观察而不模拟量化会导致模型掉点,于是需要有 FakeQuantize 来根据 Observer 观察到的数值范围模拟量化时的截断,使得参数在训练时就能提前“适应“这种操作。FakeQuantize 在前向时会根据传入的 scale 和 zero_point 对输入 Tensor 做模拟量化的操作,即先做一遍数值转换再转换后的值还原成原类型,如下所示:

def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:

scale = q_dict["scale"]

zero_point = 0

if q_dict["mode"] == QuantMode.ASYMMERTIC:

zero_point = q_dict["zero_point"]

# Quant

oup = Round()(inp / scale) + zero_point

# Clip

oup = F.minimum(F.maximum(oup, qmin), qmax)

# Dequant

oup = (oup - zero_point) * scale

return oup

目前 MegEngine 支持对 weight/activation 两部分的量化,如下所示:

ema_fakequant_qconfig = QConfig(

weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),

act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False),

weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),

act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),

)

这里使用了两种 Observer 来统计信息,而 FakeQuantize 使用了默认的算子。

如果是后量化,或者说 Calibration,由于无需进行 FakeQuantize,故而其 fake_quant 属性为 None 即可:

calibration_qconfig = QConfig(

weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),

act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False),

weight_fake_quant=None,

act_fake_quant=None,

)

除了使用在 megengine.quantization.qconfig 里提供的预设 QConfig,也可以根据需要灵活选择 Observer 和 FakeQuantize 实现自己的 QConfig。目前提供的 Observer 包括:

  • MinMaxObserver ,使用最简单的算法统计 min/max,对见到的每批数据取 min/max 跟当前存的值比较并替换,基于 min/max 得到 scale 和 zero_point;
  • ExponentialMovingAverageObserver ,引入动量的概念,对每批数据的 min/max 与现有 min/max 的加权和跟现有值比较;
  • HistogramObserver ,更加复杂的基于直方图分布的 min/max 统计算法,且在 forward 时持续更新该分布,并根据该分布计算得到 scale 和 zero_point。

对于 FakeQuantize,目前还提供了 TQT 算子,另外还可以继承 _FakeQuant 基类实现自定义的假量化算子。

在实际使用过程中,可能需要在训练时让 Observer 统计并更新参数,但是在推理时则停止更新。 Observer 和 FakeQuantize 都支持 enable() 和 disable() 功能,且 Observer 会在 train() 和 train() 时自动分别调用 enable/disable。

所以一般在 Calibration 时,会先执行 net.eval() 保证网络的参数不被更新,然后再执行 enable_observer(net) 来手动开启 Observer 的统计修改功能。

模型转换模块与相关基类

QConfig 提供了一系列如何对模型做量化的接口,而要使用这些接口,需要网络的 Module 能够在 forward 时给参数、activation 加上 Observer 和进行 FakeQuantize。转换模块的作用就是将模型中的普通 Module 替换为支持这一系列操作的 QATModule ,并能支持进一步替换成无法训练、专用于部署的 QuantizedModule 。

基于三种基类实现的 Module 是一一对应的关系,通过转换接口可以依次替换为不同实现的同名 Module。同时考虑到量化与算子融合(Fuse)的高度关联,提供了一系列预先融合好的 Module,比如 ConvRelu2d 、 ConvBn2d 和 ConvBnRelu2d 等。除此之外还提供专用于量化的 QuantStub 、 DequantStub 等辅助模块。

转换的原理很简单,就是将父 Module 中可被量化(Quantable)的子 Module 替换为对应的新 Module。但是有一些 Quantable Module 还包含 Quantable 子 Module,比如 ConvBn 就包含一个 Conv2d 和一个 BatchNorm2d,转换过程并不会对这些子 Module 进一步转换,原因是父 Module 被替换之后,其 forward 计算过程已经完全不同了,不会再依赖于这些子 Module。

Note

如果需要使一部分 Module 及其子 Module 保留 Float 状态,不进行转换,可以使用 disable_quantize() 来处理。

如果网络结构中涉及一些二元及以上的 ElementWise 操作符,比如加法乘法等,由于多个输入各自的 scale 并不一致,必须使用量化专用的算子,并指定好输出的 scale。实际使用中只需要把这些操作替换为 Elemwise 即可,比如 self.add_relu = Elemwise("FUSE_ADD_RELU")

由于转换过程修改了原网络结构, 网络的训练和测试 中提到的模型保存与加载无法直接适用于转换后的网络,读取新网络保存的参数时,需要先调用转换接口得到转换后的网络,才能用 load_state_dict 将参数进行加载。

实例讲解

下面以 ResNet18 为例来讲解量化的完整流程,完整代码见 MegEngine Models 。主要分为以下几步:

  1. 修改网络结构,使用已经 Fuse 好的 ConvBn2d、ConvBnRelu2d、ElementWise 代替原先的 Module;
  2. 在正常模式下预训练模型,并在每轮迭代保存网络检查点;
  3. 调用 quantize_qat() 转换模型,并进行 finetune;
  4. 调用 quantize() 转换为量化模型,并执行 dump 用于后续模型部署。

网络结构见 resnet.py ,相比惯常写法,修改了其中一些子 Module,将原先单独的 conv, bn, relu 替换为 Fuse 过的 Quantable Module。

class BasicBlock(Module):

def __init__(self, in_planes, planes, stride=1):

super(BasicBlock, self).__init__()

self.conv_bn_relu = ConvBnRelu2d(

in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False

)

self.conv_bn = ConvBn2d(

planes, planes, kernel_size=3, stride=1, padding=1, bias=False

)

self.add_relu = Elemwise("FUSE_ADD_RELU")

self.shortcut = Sequential()

if stride != 1 or in_planes != planes:

self.shortcut = Sequential(

ConvBn2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)

)

def forward(self, x):

out = self.conv_bn_relu(x)

out = self.conv_bn(out)

cut = self.shortcut(x)

out = self.add_relu(out, cut)

return out

然后对该模型进行若干轮迭代训练,并保存检查点,这里省略细节:

for step in range(0, total_steps):

# Linear learning rate decay

epoch = step // steps_per_epoch

learning_rate = adjust_learning_rate(step, epoch)

image, label = next(train_queue)

image = tensor(image.astype("float32"))

label = tensor(label.astype("int32"))

n = image.shape[0]

loss, acc1, acc5 = train_func(image, label, net, gm)

optimizer.step()

optimizer.clear_grad()

调用 quantize_qat() 来将网络转换为 QATModule:

from megengine.quantization import ema_fakequant_qconfig

from megengine.quantization.quantize import quantize_qat

model = ResNet18()

if args.mode != "normal":

quantize_qat(model, ema_fakequant_qconfig)

使用默认的 ema_fakequant_qconfig 来进行 int8 量化。

然后继续使用上面相同的代码进行 finetune 训练。值得注意的是,如果这两步全在一次程序运行中执行,那么训练的 trace 函数需要用不一样的,因为模型的参数变化了,需要重新进行编译。示例代码中则是采用在新的执行中读取检查点重新编译的方法。

在 QAT 模式训练完成后,继续保存检查点,执行 inference.py 并设置 mode 为 quantized ,这里需要将原始 Float 模型转换为 QAT 模型之后再加载检查点。

from megengine.quantization.quantize import quantize_qat

model = ResNet18()

if args.mode != "normal":

quantize_qat(model, ema_fakequant_qconfig)

if args.checkpoint:

logger.info("Load pretrained weights from %s", args.checkpoint)

ckpt = mge.load(args.checkpoint)

ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

model.load_state_dict(ckpt, strict=False)

模型转换为量化模型包括以下几步:

from megengine.quantization.quantize import quantize

# 定义trace函数,打开capture_as_const以进行dump

@jit.trace(capture_as_const=True)

def infer_func(processed_img):

model.eval()

logits = model(processed_img)

probs = F.softmax(logits)

return probs

# 执行模型转换

if args.mode == "quantized":

quantize(model)

# 准备数据

processed_img = transform.apply(image)[np.newaxis, :]

if args.mode == "normal":

processed_img = processed_img.astype("float32")

elif args.mode == "quantized":

processed_img = processed_img.astype("int8")

# 执行一遍evaluation

probs = infer_func(processed_img)

# 将模型 dump 导出

infer_func.dump(output_file, arg_names=["data"])

至此,便得到了一个可用于部署的量化模型。

Megengine量化的更多相关文章

  1. deeplearning模型量化实战

    deeplearning模型量化实战 MegEngine 提供从训练到部署完整的量化支持,包括量化感知训练以及训练后量化,凭借"训练推理一体"的特性,MegEngine更能保证量化 ...

  2. 旷视MegEngine核心技术升级

    旷视MegEngine核心技术升级 7 月 11 日,旷视研究院在 2020 WAIC · 开发者日「深度学习框架与技术生态论坛」上围绕 6 月底发布的天元深度学习框架(MegEngine)Beta ...

  3. MegEngine推理性能优化

    MegEngine推理性能优化 MegEngine「训练推理一体化」的独特范式,通过静态图优化保证模型精度与训练时一致,无缝导入推理侧,再借助工业验证的高效卷积优化技术,打造深度学习推理侧极致加速方案 ...

  4. 如何设计一个高内聚低耦合的模块——MegEngine 中自定义 Op 系统的实践经验

    作者:褚超群 | 旷视科技 MegEngine 架构师 背景介绍 在算法研究的过程中,算法同学们可能经常会尝试定义各种新的神经网络层(neural network layer),比如 Layer No ...

  5. 【模型推理】Tengine 模型转换及量化

      欢迎关注我的公众号 [极智视界],回复001获取Google编程规范   O_o   >_<   o_O   O_o   ~_~   o_O   本文介绍一下 Tengine 模型转换 ...

  6. Atitit  图像处理Depixelizing Pixel Art像素风格画的矢量化

    Atitit  图像处理Depixelizing Pixel Art像素风格画的矢量化 在去年的时候,偶然看到hqx算法. 一个高质量的插值放大算法. 与双线性插值等插值算法相比,这个算法放大后对人眼 ...

  7. 《量化投资:以MATLAB为工具》连载(2)基础篇-N分钟学会MATLAB(中)

    http://www.matlabsky.com/thread-43937-1-1.html   <量化投资:以MATLAB为工具>连载(3)基础篇-N分钟学会MATLAB(下)     ...

  8. 《量化投资:以MATLAB为工具》连载(1)基础篇-N分钟学会MATLAB(上)

    http://blog.sina.com.cn/s/blog_4cf8aad30102uylf.html <量化投资:以MATLAB为工具>连载(1)基础篇-N分钟学会MATLAB(上) ...

  9. 矢量化的HTML5拓扑图形组件设计

    HT一直被客户称道的就是其全矢量化的设计特色,矢量相比传统图片好处太多了: www.hightopo.com/guide/guide/core/vector/ht-vector-guide.html ...

随机推荐

  1. 服务器安装node全教程

    我的服务器centos,安装node时出了点小麻烦,在这里记述我的方法. 1.进入node下载网站https://nodejs.org/en/download/,这里右键复制下载链接 2.进入cent ...

  2. input.focus()在IOS上失效的解决方法

    之前在iphone上做开发时遇到一个问题,在一般的正常浏览器上输入以下代码: 1 2 var apple = document.getElementById('abc'); apple.focus() ...

  3. 《机器学习Python实现_10_09_集成学习_bagging_stacking原理及实现》

    介绍 前面对模型的组合主要用了两种方式: (1)一种是平均/投票: (2)另外一种是加权平均/投票: 所以,我们有时就会陷入纠结,是平均的好,还是加权的好,那如果是加权,权重又该如何分配的好?如果我们 ...

  4. Linux中Tomcat和Jboss的安装和部署

    目录 JDK环境 yum源安装JDK 源码包安装JDK Tomcat的安装 yum源安装 目录结构: 源码包安装 目录结构: 目录中主要的文件: JBoss的安装 目录结构: Tomcat是Apach ...

  5. Windows PE资源表编程(枚举资源树)

    资源枚举 写一个例子,枚举一个PE文件的资源表.首先说下资源相关的作为铺垫. 1.资源类型也是PE可选头中数据目录的一种.位于第三个类型. 2.资源目录分为三层.第四层是描述文件相关的.这些结构是按照 ...

  6. 【python】Leetcode每日一题-笨阶乘

    [python]Leetcode每日一题-笨阶乘 [题目描述] 通常,正整数 n 的阶乘是所有小于或等于 n 的正整数的乘积.例如,factorial(10) = 10 * 9 * 8 * 7 * 6 ...

  7. SQLyog连接数据库报错 plugin caching_sha2_password could not be loaded

    错误如图所示: 问题描述: 下载新版的 mysql 8.0.11 安装. 为了方便安装查看,我下载了sqlyog 工具 连接 mysql. 配置新连接报错:错误号码 2058 问题分析: mysql ...

  8. jenkins 下使用ansible 跨服务器控制操作

    例如: A服务器地址:172.16.1.203 B服务器地址:172.16.1.204 当jenkins 在A 服务器并且用户aa,  控制B 服务器的用户bb的操作 (1)B服务器 用ssh-key ...

  9. Visual Lab Online —— Beta版本发布声明

    项目 内容 班级:北航2020春软件工程 博客园班级博客 作业:Beta阶段发布声明 发布声明 目录 发布方式.发布地址与运行环境要求 软件主体 浏览器扩展 Beta版本新功能 登录注册页 注册时邮箱 ...

  10. Spring Boot 允许跨域设置失败的问题深究

    在公司开发过程中,一个前后端分离的项目遇见了跨域的问题. 前端控制台报错:No 'Access-Control-Allow-Origin' header is present on the reque ...