前置知识

混合精度训练

在参数存储时采取fp32, 开始进行fp/bp时转成fp16运算, 拿到fp16梯度后再转回fp32更新参数.

ZeRO对显存占用的估算:

  • 模型状态: Weights(fp16)、grad(fp16) 和 MasterWeights(fp32 模型参数备份),momentum(fp32)和variance(fp32)。假设模型参数量 \(\phi\) ,则共需要\(2\Phi + 2\Phi + (4\Phi + 4\Phi + 4\Phi) = 4\Phi + 12\Phi = 16\Phi\) 字节存储,
  • 剩余状态: 除了模型状态之外的显存占用,包括激活值(activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation)

Adam

在adam optimizer的计算状态除了参数, 还有一个\(m_t\)(momentum 梯度均值)和\(v_t\)(variance 梯度未中心化方差)需要存储, 一般被称为optimizer state.

AllToAll通信原语

allToall类似于矩阵转置. 相当于我们需要先把每个节点里的数据按照他们要传递给哪个节点排好序, 然后根据切分好的顺序推给对应的节点. 可以看到如果每个节点的数据量是M, 节点数是N, 最终通信总量就是M * N

ZeRO

在传统的训练方法里, 每张卡里存储一份完整的模型状态, 完成bp后allReduce grad,再更新每张卡里的副本. 这样子有N张卡就会多出(N-1)份冗余的参数存储. 当参数规模急剧增大时这种方法就完全不适合训练. ZeRO1 主要是将这些冗余的模型状态干掉, 通过增加通信来解决冗余参数的问题. ZeRO原理动态图

  • ZeRO1: 只保留一份MasterWeights+momentum+variance.
  • ZeRO2: 在ZeRO1的基础上去除了grad的冗余
  • ZeRO3: 在ZeRO2的基础上去掉了weights的冗余

训练流程

以ZeRO3为例. 主要分为5步, 假设使用了4张卡进行训练:

  1. 每张卡上存1/4的W, OS和grad. 每张卡训练自己分配到的batch.
  2. fp时, AllGather所有卡上的W,取到全量的W(fp16)进行fp, 完成后只保留自己需要维护的1/4 W, 其他显存释放回池
  3. bp时, AllGather所有卡上的W进行bp, 完成后再抛弃其他卡维护的W
  4. 完成bp后, ReduceScatter所有卡的G, 从其他卡上取到需要需要更新的梯度增量, 然后释放不是自己维护的G.
  5. 使用自己维护的OS和G来更新W, 不需要通信.


通信量分析

定义单卡数据量为\(\phi\)

传统DP: bp完成后需要对梯度进行一次AllReduce, 一共\(2\phi\)

ZeRO1: 只舍弃了OS, bp时需要AllReduce G(Scatter+Gather 共\(2\phi\)). 另外在使用每张卡各自更新W时, 因为W每张卡都存储的全量, 需要从存储OS的卡上把对应更新后的W再拉回来, 所以需要一次Gather(\(\phi\)), 一共需要\(3\phi\)

ZeRO2: 舍弃了OS和G, bp时AllGather G(\(\phi\)), 更新W时从其他卡拉W, 再Gather一次(\(\phi\)), 一共需要\(2\phi\)

ZeRO3: 上面训练过程分析过, 共需要2次Gather和1次Scatter, 一共需要\(3\phi\)

可以看到ZeRO在通信量只增加了1.5倍的情况下, 显存降了60倍. 效果非常显著

ZeRO++

ZeRO存在的问题是会在GPU之间产生大量数据传输开销,降低了训练效率. 主要有两种情况:

  1. 全局batch size较小,而 GPU数量多,这导致每个 GPU 上batch size较小,需要频繁通信

  2. 在低端集群上进行训练,其中跨节点网络带宽有限,导致高通信延迟。

ZeRO++主要采用了3部分优化: 权重量化 (qwZ), 分层分割存储 (hpZ), 梯度量化 (qgZ). 对比ZeRO通信量减少了4倍, 主要的难点都在减小量化带来的训练误差

权重量化

  1. def _quantize_int8(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
  2. q_range = 2**self.config['num_bits'] - 1
  3. min_value = tensor.amin(dim=self.config['group_dim'] + 1, keepdim=True)
  4. max_value = tensor.amax(dim=self.config['group_dim'] + 1, keepdim=True)
  5. scale = q_range / (max_value - min_value)
  6. tensor = tensor.sub_(min_value).mul_(scale)
  7. tensor = tensor_round(tensor_clamp(tensor, 0, q_range)).to(torch.uint8) #对称式量化
  8. return tensor, scale, min_value

量化kernel在deepspeedcsrc/quantization/quantize.cu cached_quantization 这个kernel里.

如果采用全局fp16->int8的量化会导致极大误差. deepspeed采用了分区量化的方法, 把参数分为固定大小的block后, 先根据这个block的max/min计算出scale(量化系数), 在把这个参数传入量化函数里. 另外在通信的时候应该也需要每个block对应的系数传给接收节点用于反量化.

\[量化公式: clip(round(scale * x), -2^{b-1}+1, 2^{b-1}-1)
\]

通过这种方式在通信量减半的同时还能保证精度, 很nice的思路.

分层分割存储

之前ZeRO的W切分方法是根据卡数均分. 在fp/bp之前进行AllGather拉取, 后来发现在机器间进行Gather通信是比较严重的瓶颈. 所以最后W的切分变成了每个节点内存储全量的W, 节点内根据卡数进行切片. 避免跨节点经过网卡的通信, 通过增加显存使用的方式解决通信瓶颈.

显存消耗: ZeRO3的单卡显存消耗为 $\frac{(2+2+K)*\phi}{N} \(, 这里每个节点多存了一份W, 如果有\)\alpha$个物理节点, 那么每张卡使用的显存就多了 \(\frac{\alpha * \phi}{N}\)

梯度量化

如果直接在之前zero RingAllReduce的通信方式上加量化和反量化, 如下图左, 可以看到需要节点个数次量化/反量化. 而每次量化都是有损的, 这样会导致无法接受的训练误差. 为了解决这个问题zero++使用了一次量化->AllToAll通信->一次反量化的操作. 而因为直接进行AllToAll通信量从M(参数量)变成了M*N/Z(N: 节点数, Z:量化压缩率), 这个通信量的增长过大. deepspeed设计了2-hpop all-to-all方法来解决通信问题.

具体图示流程可以参考Deepspeed的blog动态图, 文字版步骤:

  1. 节点内的卡间张量切片重排. 主要是因为alltoall切分成了两步, 如果不重排如下图左. 最后顺序会变错位, 然后进行参数量化

  2. 节点内alltoall通信后反量化.先把卡内能合并的梯度加起来. 这里反量化主要是为了减小梯度累加的精度损失

  3. 再次量化后, 节点间进行allToAll

  4. 拿到通信结果, 反量化后再次reduce. 得到最终的梯度.

这里要进行两次alltoall的原因主要是, 第一次卡间alltoall之后梯度累加可以减少卡数倍的通信规模. 实际deepspeed在实现的时候还把重分片和量化kernel进行了fuse, 进一步优化性能

还有下图的方法, 在通信当前层的时候, 通过多流异步量化下一层要通信的数据. 避免同步等待的浪费

参考

zero: https://arxiv.org/pdf/1910.02054

混合精度训练: https://arxiv.org/pdf/1710.03740

zero++: https://arxiv.org/abs/2306.10209

Deepspeed blog: https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md

LLM并行训练3-数据并行的更多相关文章

  1. C#并行编程之数据并行

    所谓的数据并行的条件是: 1.拥有大量的数据. 2.对数据的逻辑操作都是一致的. 3.数据之间没有顺序依赖. 运行并行编程可以充分的利用现在多核计算机的优势.记录代码如下: public class ...

  2. C#并行编程--命令式数据并行(Parallel.Invoke)---与匿名函数一起理解(转载整理)

    命令式数据并行   Visual C# 2010和.NETFramework4.0提供了很多令人激动的新特性,这些特性是为应对多核处理器和多处理器的复杂性设计的.然而,因为他们包括了完整的新的特性,开 ...

  3. C#并行编程--命令式数据并行(Parallel.Invoke)

    命令式数据并行   Visual C# 2010和.NETFramework4.0提供了很多令人激动的新特性,这些特性是为应对多核处理器和多处理器的复杂性设计的.然而,因为他们包括了完整的新的特性,开 ...

  4. ML2021 | (腾讯)PatrickStar:通过基于块的内存管理实现预训练模型的并行训练

    ​  前言  目前比较常见的并行训练是数据并行,这是基于模型能够在一个GPU上存储的前提,而当这个前提无法满足时,则需要将模型放在多个GPU上.现有的一些模型并行方案仍存在许多问题,本文提出了一种名为 ...

  5. TensorFlow分布式计算机制解读:以数据并行为重

    Tensorflow 是一个为数值计算(最常见的是训练神经网络)设计的流行开源库.在这个框架中,计算流程通过数据流程图(data flow graph)设计,这为更改操作结构与安置提供了很大灵活性.T ...

  6. SIMD数据并行(三)——图形处理单元(GPU)

    在计算机体系中,数据并行有两种实现路径:MIMD(Multiple Instruction Multiple Data,多指令流多数据流)和SIMD(Single Instruction Multip ...

  7. 百度DMLC分布式深度机器学习开源项目(简称“深盟”)上线了如xgboost(速度快效果好的Boosting模型)、CXXNET(极致的C++深度学习库)、Minerva(高效灵活的并行深度学习引擎)以及Parameter Server(一小时训练600T数据)等产品,在语音识别、OCR识别、人脸识别以及计算效率提升上发布了多个成熟产品。

    百度为何开源深度机器学习平台?   有一系列领先优势的百度却选择开源其深度机器学习平台,为何交底自己的核心技术?深思之下,却是在面对业界无奈时的远见之举.   5月20日,百度在github上开源了其 ...

  8. PyTorch如何加速数据并行训练?分布式秘籍大揭秘

    PyTorch 在学术圈里已经成为最为流行的深度学习框架,如何在使用 PyTorch 时实现高效的并行化? 在芯片性能提升有限的今天,分布式训练成为了应对超大规模数据集和模型的主要方法.本文将向你介绍 ...

  9. 深度神经网络DNN的多GPU数据并行框架 及其在语音识别的应用

    深度神经网络(Deep Neural Networks, 简称DNN)是近年来机器学习领域中的研究热点,产生了广泛的应用.DNN具有深层结构.数千万参数需要学习,导致训练非常耗时.GPU有强大的计算能 ...

  10. 【深度学习系列2】Mariana DNN多GPU数据并行框架

    [深度学习系列2]Mariana DNN多GPU数据并行框架  本文是腾讯深度学习系列文章的第二篇,聚焦于腾讯深度学习平台Mariana中深度神经网络DNN的多GPU数据并行框架.   深度神经网络( ...

随机推荐

  1. NASM中的ALIGN ALIGNB SECTALIGN

    ALIGN与ALIGNB NASM中的ALIGN与ALIGNB是用来字节对齐的,它们接收2个参数,第一个参数是必须的,表示对齐的字节数(必须是2的幂),第二个参数是可选的,表示为了对齐而进行填充的内容 ...

  2. NASM中的伪指令

    伪指令不是真正的指令,而是为了方便NASM汇编器而存在,但是它们的地位与真正的指令相同: label: instruction operands ; comment instruction部分就可以是 ...

  3. 对于Docker和Podman的一点使用经验

    前言:本文会以多个实际的线上例子,分享自己对于Docker和Podman的一点使用经验及踩过的坑,希望对读者有一点帮助. 本文bash脚本初步加工后可直接使用(兼容mac和linux系统),对于关键点 ...

  4. 数据转换2-无人机航拍倾斜摄影转换成OSGB格式

    首先软件的下载和安装参考下面链接 http://www.xue51.com/soft/53013.html 0.首先打开软件,要打开2个哦. 打数据处理开后台 ContextCapture Engin ...

  5. blocks (单调栈)

    题目描述 解析 对于这道题,他要求大于k的数进行操作,所以直接让每个数减k,然后用前缀和维护一下与0比较就可以了,因为一段区间和的平 均值大于k的话,那么这就是一个合法区间,即为操作后的这个区间和大于 ...

  6. 助力抗疫 Splashtop 远程控制软件限时免费

    近期国内疫情又有抬头趋势,给我们的工作.生活带来诸多不便.面对疫情,居家办公是一个兼顾安全健康和保持生产力的好办法.据了解,很多广州的企业现在已经在关注或开始部署远程办公方案. 为了帮助疫情中高风险地 ...

  7. Python基础篇(流程控制)

    流程控制是程序运行的基础,流程控制决定了程序按照什么样的方式执行. 条件语句 条件语句一般用来判断给定的条件是否成立,根据结果来执行不同的代码,也就是说,有了条件语句,才可以根据不同的情况做不同的事, ...

  8. pageoffice6 版本实现word 文件添加水印

    在很多场景下,Word文档正式发文之前,或者说形成最终文档之前,常常需要往Word文件中添加水印,并且会根据文件类型或内容的不同,需要添加的水印也不一样. 添加水印是Word软件里的一个简单功能,直接 ...

  9. QShop商城--项目介绍

    QShop商城-项目介绍 QShop商城,是全新推出的一款轻量级.高性能.前后端分离的电商系统,支持微信小程序,前后端源码100%开源,完美支持二次开发,让您快速搭建个性化独立商城. 技术架构:.Ne ...

  10. linux wget命令的重要用法:下载文件并保存,后台下载

    Linux wget命令是一个下载文件的工具,它用在命令行下. #从网络下载一个文件并保存在当前目录 [root@node5 ~]# wget http://cn.wordpress.org/word ...