前置知识

混合精度训练

在参数存储时采取fp32, 开始进行fp/bp时转成fp16运算, 拿到fp16梯度后再转回fp32更新参数.

ZeRO对显存占用的估算:

  • 模型状态: Weights(fp16)、grad(fp16) 和 MasterWeights(fp32 模型参数备份),momentum(fp32)和variance(fp32)。假设模型参数量 \(\phi\) ,则共需要\(2\Phi + 2\Phi + (4\Phi + 4\Phi + 4\Phi) = 4\Phi + 12\Phi = 16\Phi\) 字节存储,
  • 剩余状态: 除了模型状态之外的显存占用,包括激活值(activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation)

Adam

在adam optimizer的计算状态除了参数, 还有一个\(m_t\)(momentum 梯度均值)和\(v_t\)(variance 梯度未中心化方差)需要存储, 一般被称为optimizer state.

AllToAll通信原语

allToall类似于矩阵转置. 相当于我们需要先把每个节点里的数据按照他们要传递给哪个节点排好序, 然后根据切分好的顺序推给对应的节点. 可以看到如果每个节点的数据量是M, 节点数是N, 最终通信总量就是M * N

ZeRO

在传统的训练方法里, 每张卡里存储一份完整的模型状态, 完成bp后allReduce grad,再更新每张卡里的副本. 这样子有N张卡就会多出(N-1)份冗余的参数存储. 当参数规模急剧增大时这种方法就完全不适合训练. ZeRO1 主要是将这些冗余的模型状态干掉, 通过增加通信来解决冗余参数的问题. ZeRO原理动态图

  • ZeRO1: 只保留一份MasterWeights+momentum+variance.
  • ZeRO2: 在ZeRO1的基础上去除了grad的冗余
  • ZeRO3: 在ZeRO2的基础上去掉了weights的冗余

训练流程

以ZeRO3为例. 主要分为5步, 假设使用了4张卡进行训练:

  1. 每张卡上存1/4的W, OS和grad. 每张卡训练自己分配到的batch.
  2. fp时, AllGather所有卡上的W,取到全量的W(fp16)进行fp, 完成后只保留自己需要维护的1/4 W, 其他显存释放回池
  3. bp时, AllGather所有卡上的W进行bp, 完成后再抛弃其他卡维护的W
  4. 完成bp后, ReduceScatter所有卡的G, 从其他卡上取到需要需要更新的梯度增量, 然后释放不是自己维护的G.
  5. 使用自己维护的OS和G来更新W, 不需要通信.


通信量分析

定义单卡数据量为\(\phi\)

传统DP: bp完成后需要对梯度进行一次AllReduce, 一共\(2\phi\)

ZeRO1: 只舍弃了OS, bp时需要AllReduce G(Scatter+Gather 共\(2\phi\)). 另外在使用每张卡各自更新W时, 因为W每张卡都存储的全量, 需要从存储OS的卡上把对应更新后的W再拉回来, 所以需要一次Gather(\(\phi\)), 一共需要\(3\phi\)

ZeRO2: 舍弃了OS和G, bp时AllGather G(\(\phi\)), 更新W时从其他卡拉W, 再Gather一次(\(\phi\)), 一共需要\(2\phi\)

ZeRO3: 上面训练过程分析过, 共需要2次Gather和1次Scatter, 一共需要\(3\phi\)

可以看到ZeRO在通信量只增加了1.5倍的情况下, 显存降了60倍. 效果非常显著

ZeRO++

ZeRO存在的问题是会在GPU之间产生大量数据传输开销,降低了训练效率. 主要有两种情况:

  1. 全局batch size较小,而 GPU数量多,这导致每个 GPU 上batch size较小,需要频繁通信

  2. 在低端集群上进行训练,其中跨节点网络带宽有限,导致高通信延迟。

ZeRO++主要采用了3部分优化: 权重量化 (qwZ), 分层分割存储 (hpZ), 梯度量化 (qgZ). 对比ZeRO通信量减少了4倍, 主要的难点都在减小量化带来的训练误差

权重量化

    def _quantize_int8(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
q_range = 2**self.config['num_bits'] - 1
min_value = tensor.amin(dim=self.config['group_dim'] + 1, keepdim=True)
max_value = tensor.amax(dim=self.config['group_dim'] + 1, keepdim=True) scale = q_range / (max_value - min_value) tensor = tensor.sub_(min_value).mul_(scale)
tensor = tensor_round(tensor_clamp(tensor, 0, q_range)).to(torch.uint8) #对称式量化
return tensor, scale, min_value

量化kernel在deepspeedcsrc/quantization/quantize.cu cached_quantization 这个kernel里.

如果采用全局fp16->int8的量化会导致极大误差. deepspeed采用了分区量化的方法, 把参数分为固定大小的block后, 先根据这个block的max/min计算出scale(量化系数), 在把这个参数传入量化函数里. 另外在通信的时候应该也需要每个block对应的系数传给接收节点用于反量化.

\[量化公式: clip(round(scale * x), -2^{b-1}+1, 2^{b-1}-1)
\]

通过这种方式在通信量减半的同时还能保证精度, 很nice的思路.

分层分割存储

之前ZeRO的W切分方法是根据卡数均分. 在fp/bp之前进行AllGather拉取, 后来发现在机器间进行Gather通信是比较严重的瓶颈. 所以最后W的切分变成了每个节点内存储全量的W, 节点内根据卡数进行切片. 避免跨节点经过网卡的通信, 通过增加显存使用的方式解决通信瓶颈.

显存消耗: ZeRO3的单卡显存消耗为 $\frac{(2+2+K)*\phi}{N} \(, 这里每个节点多存了一份W, 如果有\)\alpha$个物理节点, 那么每张卡使用的显存就多了 \(\frac{\alpha * \phi}{N}\)

梯度量化

如果直接在之前zero RingAllReduce的通信方式上加量化和反量化, 如下图左, 可以看到需要节点个数次量化/反量化. 而每次量化都是有损的, 这样会导致无法接受的训练误差. 为了解决这个问题zero++使用了一次量化->AllToAll通信->一次反量化的操作. 而因为直接进行AllToAll通信量从M(参数量)变成了M*N/Z(N: 节点数, Z:量化压缩率), 这个通信量的增长过大. deepspeed设计了2-hpop all-to-all方法来解决通信问题.

具体图示流程可以参考Deepspeed的blog动态图, 文字版步骤:

  1. 节点内的卡间张量切片重排. 主要是因为alltoall切分成了两步, 如果不重排如下图左. 最后顺序会变错位, 然后进行参数量化

  2. 节点内alltoall通信后反量化.先把卡内能合并的梯度加起来. 这里反量化主要是为了减小梯度累加的精度损失

  3. 再次量化后, 节点间进行allToAll

  4. 拿到通信结果, 反量化后再次reduce. 得到最终的梯度.

这里要进行两次alltoall的原因主要是, 第一次卡间alltoall之后梯度累加可以减少卡数倍的通信规模. 实际deepspeed在实现的时候还把重分片和量化kernel进行了fuse, 进一步优化性能

还有下图的方法, 在通信当前层的时候, 通过多流异步量化下一层要通信的数据. 避免同步等待的浪费

参考

zero: https://arxiv.org/pdf/1910.02054

混合精度训练: https://arxiv.org/pdf/1710.03740

zero++: https://arxiv.org/abs/2306.10209

Deepspeed blog: https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md

LLM并行训练3-数据并行的更多相关文章

  1. C#并行编程之数据并行

    所谓的数据并行的条件是: 1.拥有大量的数据. 2.对数据的逻辑操作都是一致的. 3.数据之间没有顺序依赖. 运行并行编程可以充分的利用现在多核计算机的优势.记录代码如下: public class ...

  2. C#并行编程--命令式数据并行(Parallel.Invoke)---与匿名函数一起理解(转载整理)

    命令式数据并行   Visual C# 2010和.NETFramework4.0提供了很多令人激动的新特性,这些特性是为应对多核处理器和多处理器的复杂性设计的.然而,因为他们包括了完整的新的特性,开 ...

  3. C#并行编程--命令式数据并行(Parallel.Invoke)

    命令式数据并行   Visual C# 2010和.NETFramework4.0提供了很多令人激动的新特性,这些特性是为应对多核处理器和多处理器的复杂性设计的.然而,因为他们包括了完整的新的特性,开 ...

  4. ML2021 | (腾讯)PatrickStar:通过基于块的内存管理实现预训练模型的并行训练

    ​  前言  目前比较常见的并行训练是数据并行,这是基于模型能够在一个GPU上存储的前提,而当这个前提无法满足时,则需要将模型放在多个GPU上.现有的一些模型并行方案仍存在许多问题,本文提出了一种名为 ...

  5. TensorFlow分布式计算机制解读:以数据并行为重

    Tensorflow 是一个为数值计算(最常见的是训练神经网络)设计的流行开源库.在这个框架中,计算流程通过数据流程图(data flow graph)设计,这为更改操作结构与安置提供了很大灵活性.T ...

  6. SIMD数据并行(三)——图形处理单元(GPU)

    在计算机体系中,数据并行有两种实现路径:MIMD(Multiple Instruction Multiple Data,多指令流多数据流)和SIMD(Single Instruction Multip ...

  7. 百度DMLC分布式深度机器学习开源项目(简称“深盟”)上线了如xgboost(速度快效果好的Boosting模型)、CXXNET(极致的C++深度学习库)、Minerva(高效灵活的并行深度学习引擎)以及Parameter Server(一小时训练600T数据)等产品,在语音识别、OCR识别、人脸识别以及计算效率提升上发布了多个成熟产品。

    百度为何开源深度机器学习平台?   有一系列领先优势的百度却选择开源其深度机器学习平台,为何交底自己的核心技术?深思之下,却是在面对业界无奈时的远见之举.   5月20日,百度在github上开源了其 ...

  8. PyTorch如何加速数据并行训练?分布式秘籍大揭秘

    PyTorch 在学术圈里已经成为最为流行的深度学习框架,如何在使用 PyTorch 时实现高效的并行化? 在芯片性能提升有限的今天,分布式训练成为了应对超大规模数据集和模型的主要方法.本文将向你介绍 ...

  9. 深度神经网络DNN的多GPU数据并行框架 及其在语音识别的应用

    深度神经网络(Deep Neural Networks, 简称DNN)是近年来机器学习领域中的研究热点,产生了广泛的应用.DNN具有深层结构.数千万参数需要学习,导致训练非常耗时.GPU有强大的计算能 ...

  10. 【深度学习系列2】Mariana DNN多GPU数据并行框架

    [深度学习系列2]Mariana DNN多GPU数据并行框架  本文是腾讯深度学习系列文章的第二篇,聚焦于腾讯深度学习平台Mariana中深度神经网络DNN的多GPU数据并行框架.   深度神经网络( ...

随机推荐

  1. Intel Pentium III CPU(Coppermine, Tualatin) L2 Cache Latency, Hardware Prefetch特性调查

    这几天,偶然的机会想到了困扰自己和其他网友多年的Intel Pentium III系列处理器缓存延迟(L2 Cache Latency),以及图拉丁核心版本是否支持硬件预取(Hardware Pref ...

  2. sql计算列中并非零值的平均值

    avg不考虑空值 AVG (NULLIF(Value, 0)) NULLIF(expression, expression) 如果两个 expression 相等,则返回 NULL,该 NULL 为第 ...

  3. docker 完美部署gitea

    效果: docker-compose version: "3" networks: gitea: external: false services: server: image: ...

  4. Ubuntu20.04桌面版图文安装(超详细)

    参考文档: https://baijiahao.baidu.com/s?id=1670100505795119581&wfr=spider&for=pc https://mirrors ...

  5. postgresql性能优化2:sql语句和缓存配置

    1.看执行计划 EXPLAIN, 此命令用于查看SQL的执行计划 总的来说sql的执行计划是一个树形层次结构, 一般来说阅读上遵从层级越深越优先, 同一层级由上到下的原则. 来跟着铁蛋老师读: 层级越 ...

  6. Python:解决Matplotlib保存图片显示不全问题

    保存图片的时候设置参数bbox_inches = 'tight',如: plt.savefig("Matplotlib/graph.png", bbox_inches = 'tig ...

  7. Python OpenCV #1 - OpenCV介绍

    一.OpenCV介绍 1.1 OpenCV-Python教程简介 OpenCV由 Gary Bradsky 于1999年在英特尔创立,第一个版本于2000年发布. Vadim Pisarevsky 加 ...

  8. 一个简单demo展示接口请求超时处理

    package main import ( "context" "errors" "fmt" "time" ) type ...

  9. Qt-FFmpeg开发-回调函数读取数据(8)

    音视频/FFmpeg #Qt Qt-FFmpeg开发-使libavformat解复用器通过自定义AVIOContext读取回调访问媒体内容 目录 音视频/FFmpeg #Qt Qt-FFmpeg开发- ...

  10. win10离线安装.net3.5失败的解决方案

    简介: 问题:有时候需要离线安装.net3.5环境,网上的教程一般都是通过NetFx3.cab进行离线安装,但有时候会出现离线安装失败,比如: by~MaQaQ 2024-06-04 分析: 1.先关 ...