前置知识

Activation

激活指的是一些在fp时计算得到的临时tensor, 会用于bp时的计算. 如果能在fp计算后把临时tensor缓存下来就可以加速bp, 缺点在于某些激活会占用大量显存. 以一层transformer结构为例分析下各层的存在的激活.

简单部分的分析这里忽略. 主要分析下几个不好理解的计算:

  1. \(QK^T\): 需要缓存Q输出和K输出 \(2sbh+2sbh\)
  2. \(softmax\): a个head, 每个head进行s次softmax概率分类, 一共有b*s个token, 输出为fp16, 所以是\(2 * a * s * (s * b) = 2as^2b\)
  3. \(dropout\): 需要缓存一个mask数组标记哪些位置的token需要被反向更新. 所以每个位置只需要bool, 所以是\(a * s * (s * b) = as^2b\)

如果激活的临时显存全部不释放, 每个transformerLayer需要占用的量级为\(sbh(34+5\frac{as}{h})\) (不包含任何并行加速的情况)

选择性激活重计算

重计算部分: 以GPT-3为例: a = 96, s = 2048, and h = 12288, 5as/h = 80. 占了总激活显存的70%左右的主要是attention里的部分比如: softmax的输出, dropout的mask, dropout的输出. 这些部分的激活都跟矩阵乘法没关系, 所需要的计算量很小. 把这些在前向计算完之后直接释放显存, 在bp用的时候再重计算的话就能以很小的计算代价换取70%的显存存储.

拆分存储部分: 而比如像线性层的中间输出, Q,K,V这些中间tensor. 因为前面有一个和W的矩阵乘法, 重计算代价巨大更适合临时缓存起来用于bp计算复用. 这些激活通过张量并行和序列并行把他们拆分存储, 在bp使用的时候通过集合通信的方式再拉取, 来减少显存消耗.

张量并行共计24sbh: 并行后拆分为t份的包括:

  • attention部分: QKV的输入2sbh, \(QK^T\)结果4sbh, 和V相乘结果里的2sbh. 共计8sbh
  • MLP部分: 两个8sbh, 因为TP的先列在行的计算方法全部平分成了t份, 共计16sbh

序列并行共计10sbh, 拆分成t份的包括:

  1. attention部分: layerNorm的输入输出4sbh, dropout的mask sbh
  2. MLP部分: layerNorm的输入输出4sbh, dropout的mask sbh

序列并行

序列并行指的是在 Transformer 层的非张量并行切分部分,计算在序列维度(s)是独立的. 所以在这个维度上以切分数量和张量并行数相同的方式进行切分.

为什么切分数要和TP相等呢? 回忆下TP后如何汇聚计算结果, 在行并行后allReduce, 把每张卡各自计算的结果纵向拼起来还原成完整的输入. 如果我们想把这部分完整的输入进行切分存储, 如果切分数量和TP不一致意味着在$\bar{g} $ 这个地方在allReduce之后还要再进行一次reduceScatter进行分割, 在\(g\)地方再allGather..

而如果切分数和TP相等, 在\(\bar{g}\)这里可以把allReduce直接省掉, 相当于把allReduce拆分的两个操作. 在通信量保持不变的情况下分离了layerNorm的激活

Zero-R

激活分区 & checkpointing

这里提到在模型并行的时候activation会存在冗余副本. 这里应该就是指的是TP输入的冗余副本. 论文里是说到会把激活给partition多份..其实我感觉实现方法和megatron里的序列并行就是一样的, 标记一下等细看deepspeed代码的时候再确认下.

另外还提到个新方法利用内存来存激活checkpoint, 想了下应该是类似下图的步骤

  1. 在最初始的几层fp和bp的时间间隔比较远, 适合在做完fp后memcpyAsync到内存.
  2. 靠近loss的后面几层激活还存在显存里, 在bp的时候直接用完就释放了. 快到cpu激活的部分通过memcpyAsync回来.
  3. 如果训练和copy激活能使用同一个stream, 那么这块就不需要同步, 按流的顺序实行即可

恒定大小显存缓冲区

像all-reduce这些集合通信操作, 在一次通信一批很大的数据效率很高. 但缺点是会分配大量的临时显存, 这样会导致显存出现较大波动, 在大模型场景会出现问题.

所以zero在这块设定了一个固定的buffer_size, 超过buffer_size的时候分批次通信. cpu激活checkpointing的copy应该也需要相应的方式.

其实在写flux-gpu的sparse拷贝的时候也用了类似的方法..分批次拷贝来避免单次的超大通信和小数据的碎片通信.

显存碎片解决方法

显存碎片产生的原因: 在fp的时候只有一部分激活储存下来用于bp, 另外一些需要在bp重算的激活被释放了.就会导致一部分显存的使用周期很长, 另一部分很短, 从而产生显存碎片. 会导致两个问题: 1. 显存allocator查找满足大小的显存块效率很低 2. 可能会出现大块连续显存分配不出来.

论文里说会给activation和grad预分配好连续显存块..emm, 这个做法看着和llm.c里的实现是一样的, 其实在大模型里大部分的w/grad/activation在运行的时候都是定长的, 我们完全可以在第一次运行的时候全部分配好. 在网络计算的时候避免显存分配. 如果不使用allocator就不会有碎片问题.

Zero-Offload

主要用来解决模型规模远大于显存规模的问题, 看着灰常似曾相识, 和部署在本地内存的参数服务器很像. 感觉区别在于2点: 1. 训练是同步的, gpu在cpu更新optimizer_state的时候只能处于等待状态 2.内存里保存全量的参数, 不进行多机通信(多机通信应该会让本来就慢的cpu更加雪上加霜吧haha).

计算策略

  1. 保证 CPU 的计算负担远远小于 GPU,从而防止 CPU 成为计算瓶颈;保证 GPU 的内存节省最大;(optimizer_state是最占显存的同时, 也是不需要反复计算的, 一个batch里只需要存取一次, 不像fp16的w一样还会参与反向的梯度计算)
  2. 保证 CPU 和 GPU 之间的通信量最小;(在通信的时候进行量化和反量化)

调度策略

offload采用的时候zero2方案, 也就是fp16 w是分卡存储的. 考虑使用zero2的很重要的一个原因我猜测是在于多卡可以同时copy w, 而且没有冗余数据通信. 避免pciE带宽拖后腿.

下图是单卡的数据流, swap的部分论文里画错了应该是CPU->GPU, 通信和计算异步的地方主要有2处:

  1. g offload, 是在gpu bp的时候每计算完一层的g就async copy到内存
  2. p swap, cpu更新完一批w, 就分块进行量化和async copy到显存.

Fp16 w到了显存里后就和不同的zero2计算流程完全一样了.

后面还有一个和推荐模型cpu异步训练类似的cpu操作全隐藏训练模式, 只不过区别是把异步训练的n个batch对齐dense改成了固定1个batch.

Zero-Offload++

在第一版offload的时候, 所有的参数都是在cpu计算的, 上面也说到了. 在cpu计算的时候gpu只能空等, 如何在空等的时间窗口把gpu利用起来是一个很大的问题. offload++给了一个很棒的思路. 设置了一个os_w的存储比例, 以图示为例, 有40%的os_w存在内存里由cpu更新, 剩下的60%由gpu更新. 步骤如下:

  1. 在bp完靠上层40%的网络后, 把g往内存copy
  2. CPU开始逐步计算已经拉下来的g, 更新os_w. 把属于自己更新的那部分算完
  3. 到达属于GPU更新的部分后, GPU allScatter 剩下的60% grad到os_w存储的对应卡上, 更新显存里的os_w
  4. 等cpu算完后量化的fp16_w copy回显存和显存里的fp16_w合并, 进行下一轮计算

这里的比值是人工设置的, 设置原理就是在尽量把显存用满的前提下尽可能的往GPU塞os_w, 塞不下的再放内存里. 这个思路感觉超棒, 待细看代码

参考

Megatron-LM论文: https://arxiv.org/pdf/2205.05198

zero-R论文: https://arxiv.org/abs/1910.02054

zero-offload: https://www.usenix.org/system/files/atc21-ren-jie.pdf

zero-offload++博客: https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp

megatron论文解读: https://diveblue.notion.site/Megatron-3-Reducing-Activation-Recomputation-in-Large-Transformer-Models-b4d8bfacd33c449383aa9c61123ab578#7c3cc0cb24c444b898c4e435d12bbd4f

LLM并行训练6-激活优化的更多相关文章

  1. ML2021 | (腾讯)PatrickStar:通过基于块的内存管理实现预训练模型的并行训练

    ​  前言  目前比较常见的并行训练是数据并行,这是基于模型能够在一个GPU上存储的前提,而当这个前提无法满足时,则需要将模型放在多个GPU上.现有的一些模型并行方案仍存在许多问题,本文提出了一种名为 ...

  2. PyTorch如何加速数据并行训练?分布式秘籍大揭秘

    PyTorch 在学术圈里已经成为最为流行的深度学习框架,如何在使用 PyTorch 时实现高效的并行化? 在芯片性能提升有限的今天,分布式训练成为了应对超大规模数据集和模型的主要方法.本文将向你介绍 ...

  3. [Pytorch框架] 4.5 多GPU并行训练

    文章目录 4.5 多GPU并行训练 4.5.1 torch.nn.DataParalle 4.5.2 torch.distributed 4.5.3 torch.utils.checkpoint im ...

  4. Pytorch:单卡多进程并行训练

    1 导引 我们在博客<Python:多进程并行编程与进程池>中介绍了如何使用Python的multiprocessing模块进行并行编程.不过在深度学习的项目中,我们进行单机多进程编程时一 ...

  5. tensorflow 13:多gpu 并行训练

    多卡训练模式: 进行深度学习模型训练的时候,一般使用GPU来进行加速,当训练样本只有百万级别的时候,单卡GPU通常就能满足我们的需求,但是当训练样本量达到上千万,上亿级别之后,单卡训练耗时很长,这个时 ...

  6. .Net中的并行编程-6.常用优化策略

                本文是.Net中的并行编程第六篇,今天就介绍一些我在实际项目中的一些常用优化策略.      一.避免线程之间共享数据 避免线程之间共享数据主要是因为锁的问题,无论什么粒度的锁 ...

  7. [源码解析] 模型并行分布式训练Megatron (5) --Pipedream Flush

    [源码解析] 模型并行分布式训练Megatron (5) --Pipedream Flush 目录 [源码解析] 模型并行分布式训练Megatron (5) --Pipedream Flush 0x0 ...

  8. [源码解析] PyTorch分布式优化器(3)---- 模型并行

    [源码解析] PyTorch分布式优化器(3)---- 模型并行 目录 [源码解析] PyTorch分布式优化器(3)---- 模型并行 0x00 摘要 0x01 前文回顾 0x02 单机模型 2.1 ...

  9. MySQL 并行复制演进及 MySQL 8.0 中基于 WriteSet 的优化

    MySQL 8.0 可以说是MySQL发展历史上里程碑式的一个版本,包括了多个重大更新,目前 Generally Available 版本已经已经发布,正式版本即将发布,在此将介绍8.0版本中引入的一 ...

  10. Java 进阶7 并发优化 1 并行程序的设计模式

       本章重点介绍的是基于 Java并行程序开发以及优化的方法,对于多核的 CPU,传统的串行程序已经很好的发回了 CPU性能,此时如果想进一步提高程序的性能,就应该使用多线程并行的方式挖掘 CPU的 ...

随机推荐

  1. gorm指定数据字段名字

    type Products struct { gorm.Model SaleNum uint ` json:"saleNum"` CarNum uint ` json:" ...

  2. 一个IDEA界面如何同时打开多个项目

    第一步:先导入其中一个工程 第二步:点击File->Project Structure 第三步:导入模块 最后点击Apply即可完成一个IDEA界面同时打开多个项目的需求.

  3. 七年之痒!一个 PHP 程序员职业生涯的自述

    大家好,我是码农先森. 今年刚好是我毕业的第七个年头,在婚姻感情当中都有一种「七年之痒」的说法,这次我把这个词「七年之痒」用一次在我的职业生涯复盘上.七年前我从告别校园,踏入互联网编程行业,七年后我依 ...

  4. 经验分享之会员 SaaS 系统

    经验分享之会员 SaaS 系统 一.前言 2018年,这是不平凡的一年:互联网行业的中台战略.会员经济等模式如火如荼,同时也逐渐地走入我们公司每个人的视野.在南海集团的战略规划背景下,当时我所在的公司 ...

  5. vue.js的M-V-VM思想

    MVVM 是Model-View-ViewModel 的缩写,它是一种基于前端开发的架构模式. Model 指代的就是vue对象的data属性里面的数据.这里的数据要显示到页面中. View 指代的就 ...

  6. java stream 简单函数

    写在前面 本文为笔者学习的一些心得,如有问题,评论请轻喷 本文分为以下部分: 中间操作 终止操作 归纳 中间操作 对 list 进行操作,返回一个新的 list 主要函数 作用 filter 过滤操作 ...

  7. 1.Prism

    Region(区域)在程序编写的过程中我们肯定会遇到在一个区域上显示不同的内容,这些内容可能属于不同窗口,之前是弄个panel,需要显示哪个窗口就给让panel显示. 1.定义区域2.提供对区域的访问 ...

  8. docker资源限制与数据持久化

    1. docker简介和核心概念 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使 ...

  9. GNU GCC学习

    1 Introduction 参考视频:1 GCC简介_哔哩哔哩_bilibili 参考书籍:<An Introduction to GCC (Brian J. Gough, Richard.p ...

  10. echarts做饼图

    今天记录下echarts做饼图 父组件 <el-card style="height:600px ;margin-top:20px" v-loading="card ...