基于 Quanto 和 Diffusers 的内存高效 transformer 扩散模型
过去的几个月,我们目睹了使用基于 transformer 模型作为扩散模型的主干网络来进行高分辨率文生图 (text-to-image,T2I) 的趋势。和一开始的许多扩散模型普遍使用 UNet 架构不同,这些模型使用 transformer 架构作为扩散过程的主模型。由于 transformer 的性质,这些主干网络表现出了良好的可扩展性,模型参数量可从 0.6B 扩展至 8B。
随着模型越变越大,内存需求也随之增加。对扩散模型而言,这个问题愈加严重,因为扩散流水线通常由多个模型串成: 文本编码器、扩散主干模型和图像解码器。此外,最新的扩散流水线通常使用多个文本编码器 - 如: Stable Diffusion 3 有 3 个文本编码器。使用 FP16 精度对 SD3 进行推理需要 18.765GB 的 GPU 显存。
这么高的内存要求使得很难将这些模型运行在消费级 GPU 上,因而减缓了技术采纳速度并使针对这些模型的实验变得更加困难。本文,我们展示了如何使用 Diffusers 库中的 Quanto 量化工具脚本来提高基于 transformer 的扩散流水线的内存效率。
基础知识
你可参考 这篇文章 以获取 Quanto 的详细介绍。简单来说,Quanto 是一个基于 PyTorch 的量化工具包。它是 Hugging Face Optimum 的一部分,Optimum 提供了一套硬件感知的优化工具。
模型量化是 LLM 从业者必备的工具,但在扩散模型中并不算常用。Quanto 可以帮助弥补这一差距,其可以在几乎不伤害生成质量的情况下节省内存。
我们基于 H100 GPU 配置进行基准测试,软件环境如下:
除非另有说明,我们默认使用 FP16 进行计算。我们不对 VAE 进行量化以防止数值不稳定问题。你可于 此处 找到我们的基准测试代码。
截至本文撰写时,以下基于 transformer 的扩散模型流水线可用于 Diffusers 中的文生图任务:
另外还有一个基于 transformer 的文生视频流水线: Latte。
为简化起见,我们的研究仅限于以下三个流水线: PixArt-Sigma、Stable Diffusion 3 以及 Aura Flow。下表显示了它们各自的扩散主干网络的参数量:
模型 | Checkpoint | 参数量(Billion) |
---|---|---|
PixArt | https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS | 0.611 |
Stable Diffusion 3 | https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers | 2.028 |
Aura Flow | https://huggingface.co/fal/AuraFlow/ | 6.843 |
用 Quanto 量化 DiffusionPipeline
使用 Quanto 量化模型非常简单。
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import PixArtSigmaPipeline
import torch
pipeline = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")
quantize(pipeline.transformer, weights=qfloat8)
freeze(pipeline.transformer)
我们对需量化的模块调用 quantize()
,以指定我们要量化的部分。上例中,我们仅量化参数,保持激活不变,量化数据类型为 FP8。最后,调用 freeze()
以用量化参数替换原始参数。
然后,我们就可以如常调用这个 pipeline
了:
image = pipeline("ghibli style, a fantasy landscape with castles").images[0]
FP16 | 将 transformer 扩散主干网络量化为 FP8 |
---|---|
我们注意到使用 FP8 可以节省显存,且几乎不影响生成质量; 我们也看到量化模型的延迟稍有变长:
Batch Size | 量化 | 内存 (GB) | 延迟 (秒) |
---|---|---|---|
1 | 无 | 12.086 | 1.200 |
1 | FP8 | 11.547 | 1.540 |
4 | 无 | 12.087 | 4.482 |
4 | FP8 | 11.548 | 5.109 |
我们可以用相同的方式量化文本编码器:
quantize(pipeline.text_encoder, weights=qfloat8)
freeze(pipeline.text_encoder)
文本编码器也是一个 transformer 模型,我们也可以对其进行量化。同时量化文本编码器和扩散主干网络可以带来更大的显存节省:
Batch Size | 量化 | 是否量化文本编码器 | 显存 (GB) | 延迟 (秒) |
---|---|---|---|---|
1 | FP8 | 否 | 11.547 | 1.540 |
1 | FP8 | 是 | 5.363 | 1.601 |
4 | FP8 | 否 | 11.548 | 5.109 |
4 | FP8 | 是 | 5.364 | 5.141 |
量化文本编码器后生成质量与之前的情况非常相似:
上述攻略通用吗?
将文本编码器与扩散主干网络一起量化普遍适用于我们尝试的很多模型。但 Stable Diffusion 3 是个特例,因为它使用了三个不同的文本编码器。我们发现 _ 第二个 _ 文本编码器量化效果不佳,因此我们推荐以下替代方案:
- 仅量化第一个文本编码器 (
CLIPTextModelWithProjection
) 或 - 仅量化第三个文本编码器 (
T5EncoderModel
) 或 - 同时量化第一个和第三个文本编码器
下表给出了各文本编码器量化方案的预期内存节省情况 (扩散 transformer 在所有情况下均被量化):
Batch Size | 量化 | 量化文本编码器 1 | 量化文本编码器 2 | 量化文本编码器 3 | 显存 (GB) | 延迟 (秒) |
---|---|---|---|---|---|---|
1 | FP8 | 1 | 1 | 1 | 8.200 | 2.858 |
1 | FP8 | 0 | 0 | 1 | 8.294 | 2.781 |
1 | FP8 | 1 | 1 | 0 | 14.384 | 2.833 |
1 | FP8 | 0 | 1 | 0 | 14.475 | 2.818 |
1 | FP8 | 1 | 0 | 0 | 14.384 | 2.730 |
1 | FP8 | 0 | 1 | 1 | 8.325 | 2.875 |
1 | FP8 | 1 | 0 | 1 | 8.204 | 2.789 |
1 | 无 | - | - | - | 16.403 | 2.118 |
量化文本编码器: 1 | 量化文本编码器: 3 | 量化文本编码器: 1 和 3 |
---|---|---|
其他发现
在 H100 上 bfloat16
通常表现更好
对于支持 bfloat16
的 GPU 架构 (如 H100 或 4090),使用 bfloat16
速度更快。下表列出了在我们的 H100 参考硬件上测得的 PixArt 的一些数字: Batch Size 精度 量化 显存 (GB) 延迟 (秒) 是否量化文本编码器
Batch Size | 精度 | 量化 | 显存(GB) | 延迟(秒) | 是否量化文本编码器 |
---|---|---|---|---|---|
1 | FP16 | INT8 | 5.363 | 1.538 | 是 |
1 | BF16 | INT8 | 5.364 | 1.454 | 是 |
1 | FP16 | FP8 | 5.363 | 1.601 | 是 |
1 | BF16 | FP8 | 5.363 | 1.495 | 是 |
qint8
的前途
我们发现使用 qint8
(而非 qfloat8
) 进行量化,推理延迟通常更好。当我们对注意力 QKV 投影进行水平融合 (在 Diffusers 中调用 fuse_qkv_projections()
) 时,效果会更加明显,因为水平融合会增大 int8 算子的计算维度从而实现更大的加速。我们基于 PixArt 测得了以下数据以证明我们的发现:
Batch Size | 量化 | 显存 (GB) | 延迟 (秒) | 是否量化文本编码器 | QKV 融合 |
---|---|---|---|---|---|
1 | INT8 | 5.363 | 1.538 | 是 | 否 |
1 | INT8 | 5.536 | 1.504 | 是 | 是 |
4 | INT8 | 5.365 | 5.129 | 是 | 否 |
4 | INT8 | 5.538 | 4.989 | 是 | 是 |
INT4 咋样?
在使用 bfloat16
时,我们还尝试了 qint4
。目前我们仅支持 H100 上的 bfloat16
的 qint4
量化,其他情况尚未支持。通过 qint4
,我们期望看到内存消耗进一步降低,但代价是推理延迟变长。延迟增加的原因是硬件尚不支持 int4 计算 - 因此权重使用 4 位,但计算仍然以 bfloat16
完成。下表展示了 PixArt-Sigma 的结果:
Batch Size | 是否量化文本编码器 | 显存 (GB) | 延迟 (秒) |
---|---|---|---|
1 | 否 | 9.380 | 7.431 |
1 | 是 | 3.058 | 7.604 |
但请注意,由于 INT4 量化比较激进,最终结果可能会受到影响。所以,一般对于基于 transformer 的模型,我们通常不量化最后一个投影层。在 Quanto 中,我们做法如下:
quantize(pipeline.transformer, weights=qint4, exclude="proj_out")
freeze(pipeline.transformer)
"proj_out"
对应于 pipeline.transformer
的最后一层。下表列出了各种设置的结果:
量化文本编码器: 否 , 不量化的层: 无 | 量化文本编码器: 否 , 不量化的层: "proj_out" | 量化文本编码器: 是 , 不量化的层: 无 | 量化文本编码器: 是 , 不量化的层: "proj_out" |
---|---|---|---|
为了恢复损失的图像质量,常见的做法是进行量化感知训练,Quanto 也支持这种训练。这项技术超出了本文的范围,如果你有兴趣,请随时与我们联系!
本文的所有实验结果都可以在 这里 找到。
加个鸡腿 - 在 Quanto 中保存和加载 Diffusers 模型
以下代码可用于对 Diffusers 模型进行量化并保存量化后的模型:
from diffusers import PixArtTransformer2DModel
from optimum.quanto import QuantizedPixArtTransformer2DModel, qfloat8
model = PixArtTransformer2DModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="transformer")
qmodel = QuantizedPixArtTransformer2DModel.quantize(model, weights=qfloat8)
qmodel.save_pretrained("pixart-sigma-fp8")
此代码生成的 checkpoint 大小为 587MB ,而不是原本的 2.44GB。然后我们可以加载它:
from optimum.quanto import QuantizedPixArtTransformer2DModel
import torch
transformer = QuantizedPixArtTransformer2DModel.from_pretrained("pixart-sigma-fp8")
transformer.to(device="cuda", dtype=torch.float16)
最后,在 DiffusionPipeline
中使用它:
from diffusers import DiffusionPipeline
import torch
pipe = DiffusionPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
transformer=None,
torch_dtype=torch.float16,
).to("cuda")
pipe.transformer = transformer
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
将来,我们计划支持在初始化流水线时直接传入 transformer
就可以工作:
pipe = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
- transformer=None,
+ transformer=transformer,
torch_dtype=torch.float16,
).to("cuda")
QuantizedPixArtTransformer2DModel
实现可参考 此处。如果你希望 Quanto 支持对更多的 Diffusers 模型进行保存和加载,请在 此处 提出需求并 @sayakpaul
。
小诀窍
- 根据应用场景的不同,你可能希望对流水线中不同的模块使用不同类型的量化。例如,你可以对文本编码器进行 FP8 量化,而对 transformer 扩散模型进行 INT8 量化。由于 Diffusers 和 Quanto 的灵活性,你可以轻松实现这类方案。
- 为了优化你的用例,你甚至可以将量化与 Diffuser 中的其他 内存优化技术 结合起来,如
enable_model_cpu_offload()
。
总结
本文,我们展示了如何量化 Diffusers 中的 transformer 模型并优化其内存消耗。当我们同时对文本编码器进行量化时,效果变得更加明显。我们希望大家能将这些工作流应用到你的项目中并从中受益。
感谢 Pedro Cuenca 对本文的细致审阅。
英文原文: https://hf.co/blog/quanto-diffusers
原文作者: Sayak Paul,David Corvoysier
译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。
基于 Quanto 和 Diffusers 的内存高效 transformer 扩散模型的更多相关文章
- Java安全之基于Tomcat的Filter型内存马
Java安全之基于Tomcat的Filter型内存马 写在前面 现在来说,内存马已经是一种很常见的攻击手法了,基本红队项目中对于入口点都是选择打入内存马.而对于内存马的支持也是五花八门,甚至各大公司都 ...
- 三维CAD塑造——基于所述基本数据结构一半欧拉操作模型
三维CAD塑造--基于所述基本数据结构一半欧拉操作模型(elar, B_REP) (欧拉操作 三维CAD建模课程 三维CAD塑造 高曙明老师 渲染框架 brep 带洞 带柄 B_REP brep ...
- 牛亚男:基于多Domain多任务学习框架和Transformer,搭建快精排模型
导读: 本文主要介绍了快手的精排模型实践,包括快手的推荐系统,以及结合快手业务展开的各种模型实战和探索,全文围绕以下几大方面展开: 快手推荐系统 CTR模型--PPNet 多domain多任务学习框架 ...
- 基于HTML5的WebGL应用内存泄露分析
上篇(http://www.hightopo.com/blog/194.html)我们通过定制了CPU和内存展示界面,体验了HT for Web通过定义矢量实现图形绘制与业务数据的代码解耦及绑定联动, ...
- 基于JDK1.8的JVM 内存结构【JVM篇三】
目录 1.内存结构还是运行时数据区? 2.运行时数据区 3.线程共享:Java堆.方法区 4.线程私有:程序计数器.Java 虚拟机栈.本地方法栈 5.JVM 内存结构总结 在我的上一篇文章别翻了,这 ...
- [转载]查看基于Android 系统单个进程内存、CPU使用情况的几种方法
转载自: http://www.linuxidc.com/Linux/2011-11/47587.htm 一.利用Android API函数查看1.1 ActivityManager查看可用内存. A ...
- R 语言中 data table 的相关,内存高效的 增量式 data frame
面对的是这样一个问题,不断读入一行一行数据,append到data frame上,如果用dataframe, rbind() ,可以发现数据大的时候效率明显变低. 原因是 每次bind 都是一次重新 ...
- 撸代码--类QQ聊天实现(基于linux 管道 信号 共享内存)
一:任务描写叙述 A,B两个进程通过管道通信,像曾经的互相聊天一样,然后A进程每次接收到的数据通过A1进程显示(一个新进程,用于显示A接收到的信息),A和A1间的数据传递採用共享内存,相应的有一个B1 ...
- 基于JVM原理、JMM模型和CPU缓存模型深入理解Java并发编程
许多以Java多线程开发为主题的技术书籍,都会把对Java虚拟机和Java内存模型的讲解,作为讲授Java并发编程开发的主要内容,有的还深入到计算机系统的内存.CPU.缓存等予以说明.实际上,在实际的 ...
- 基于贝叶斯网(Bayes Netword)图模型的应用实践初探
1. 贝叶斯网理论部分 笔者在另一篇文章中对贝叶斯网的理论部分进行了总结,在本文中,我们重点关注其在具体场景里的应用. 2. 从概率预测问题说起 0x1:条件概率预测模型之困 我们知道,朴素贝叶斯分类 ...
随机推荐
- C# 时间戳与 标准时间互转
C# 时间戳与 标准时间的转其实不难,但需要注意下,基准时间的问题. 格林威治时间起点: 1970 年 1 月 1 日的 00:00:00.000 北京时间起点:1970 年 1 月 1 日的 08: ...
- power bi 如何删除敏感度标签
经验证,此方法不够彻底,我的office excel打开后还是要添加敏感度标签,即使我把敏感度标签删掉也不行. 当我把创建敏感度标签的管理员账户删掉之后,虽然打开excel还是会显示敏感度标签,但是已 ...
- ubuntu podman相关
前言 记录podman的安装.配置以及一些常用操作,会不定时更新: 正文 1. podman 安装以及配置 ubuntu 安装 podman sudo apt update sudo apt inst ...
- Flask API 如何接入 i18n 实现国际化多语言
1. 介绍 上一篇文章分享了 Vue3 如何如何接入 i18n 实现国际化多语言,这里继续和大家分享 Flask 后端如何接入 i18n 实现国际化多语言. 用户请求 API 的多语言化其实有两种 ...
- leetcode 中等(设计):[146, 155, 208, 211, 284, 304, 307, 341, 355, 380]
目录 146. LRU 缓存 155. 最小栈 208. 实现 Trie (前缀树) 211. 添加与搜索单词 - 数据结构设计 284. 顶端迭代器 304. 二维区域和检索 - 矩阵不可变 307 ...
- oeasy教您玩转vim - 2 - # 使用帮助
回忆上节课内容 更新和运行 vim 进入和退出 vim 存活了下来 从中我们知道 vim 有两种模式:正常模式(Normal mode)和命令行模式 (Command-Line mode) 为了您能更 ...
- PHP进阶
只是简要说明起原理和用法,具体可以百度 abstract 抽象类 抽象类是指在 class 前加了 abstract 关键字且存在抽象方法,不带{},如public function test() i ...
- 【WPF】Command 的一些使用方案
Command,即命令,具体而言,指的是实现了 ICommand 接口的对象.此接口要求实现者包含这些成员: 1.CanExecute 方法:确定该命令是否可以执行,若可,返回 true:若不可,返回 ...
- Python异常处理try+except用法
1.except是用来捕获程序异常的 异常代码如: ModuleNotFoundError(没有找到模块,安装提示的模块即可) AttributeError(没有访问属性) TypeError(类型错 ...
- PHP转Go系列 | Carbon 时间处理工具的使用姿势
大家好,我是码农先森. 在日常的开发过程中经常会遇到对时间的处理,比如将时间戳进行格式化.获取昨天或上周或上个月的时间.基于当前时间进行加减等场景的使用.在 PHP 语言中有一个针对时间处理的原生函数 ...