开源 SD-Small 和 SD-Tiny 知识蒸馏代码与权重
最近,人工智能社区在开发更大、更高性能的语言模型方面取得了显著的进展,例如 Falcon 40B、LLaMa-2 70B、Falcon 40B、MPT 30B; 以及在图像领域的模型,如 SD2.1 和 SDXL 。这些进步无疑推动了人工智能的发展,使其具有高度多功能和最先进的图像生成和语言理解能力。然而,在我们惊叹于这些模型的强大和复杂性之余,必须认识到一个日益增长的需求: 使人工智能模型体量更小、运行更高效、更易于访问,特别是通过开源它们来共建生态。
在 Segmind,我们一直致力于如何使生成式 AI 更快、更便宜。去年,我们开源了我们加速的 SD-WebUI 库 voltaML,它是一个基于 AITemplate/TensorRT 的推理加速库,推理速度提高了 4-6 倍。为了继续实现使生成模型更快、更小、更便宜的目标,我们正在开源我们压缩的 SD 模型:SD-Small 和 SD-Tiny 的权重和训练代码。预训练的检查点可在 Hugging Face 上获取。
知识蒸馏
我们的新压缩模型已经经过知识蒸馏 (KD) 技术的训练,这项工作主要基于 这篇论文。作者描述了一种块移除知识蒸馏方法,其中一些 UNet 层被移除,学生模型权重被训练。使用论文中描述的 KD 方法,我们能够使用 diffusers 库训练两个压缩模型; Small(微小版本) 和 Tiny(极小版本),分别比基础模型少 35% 和 55% 的参数,同时实现与基础模型相当的图像保真度。我们已经在这个 repo 中开源了我们的蒸馏代码,并将预训练检查点上传到了 Hugging Face。
知识蒸馏训练神经网络类似于老师一步一步指导学生。一个大的老师模型 (teacher model) 预先在大量数据上训练,然后一个较小的模型在较小的数据集上训练,以模仿大模型的输出并在数据集上进行经典训练。
在这种特殊类型的知识蒸馏中,学生模型被训练来完成从纯噪声恢复图像的正常扩散任务,但同时,模型被迫与更大的老师模型的输出匹配。输出匹配发生在 U-nets 的每个块,因此模型质量基本保持不变。所以,使用前面的类比,我们可以说,在这种蒸馏过程中,学生不仅会试图从问题和答案中学习,还会从老师的答案以及逐步得到答案的方法中学习。我们在损失函数中有 3 个组成部分来实现这一点,首先是目标图像隐变量和生成图像隐变量之间的传统损失。其次是老师生成的图像隐变量和学生生成的图像隐变量之间的损失。最后,也是最重要的组成部分,是特征级损失,即老师和学生每个块输出之间的损失。
结合所有这些构成了知识蒸馏训练。下面是论文中描述的用于 KD 的块移除 UNet 架构。
图片来自 Shinkook 等人的 论文 “On Architectural Compression of Text-to-Image Diffusion Models”。
我们以 Realistic-Vision 4.0 为基础老师模型,并在LAION Art Aesthetic 数据集 上训练,图像分数高于 7.5,因为它们具有高质量的图像描述。与论文不同,我们选择分别为 Small 和 Tiny 模式训练两个模型,分别在 1M 张图像上进行 100K 步和 125K 步的训练。蒸馏训练的代码可以在 这里 找到。
模型使用
模型可以通过 diffusers 中的 DiffusionPipeline 来使用。
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained("segmind/small-sd", torch_dtype=torch.float16)
prompt = "Portrait of a pretty girl"
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
image = pipeline(prompt, negative_prompt = negative_prompt).images[0]
image.save("my_image.png")
推理延迟方面的速度表现
我们观察到,蒸馏模型比原始基础模型快了一倍。基准测试代码可以在 这里 找到。
潜在的局限性
蒸馏模型处于早期阶段,输出可能还不具备生产水平的质量。这些模型可能不是最好的通用模型,它们最好用作针对特定概念/风格进行微调或 LoRA 训练。蒸馏模型目前还不太擅长组合性或多概念。
在人像数据集上微调 SD-tiny 模型
我们已经在 Realistic Vision v4.0 模型生成的人像图像上微调了我们的 sd-tiny 模型。下面是使用的微调参数。
原版参数 | 中文释义 |
---|---|
Steps: 131000 | 步数: 131000 |
Learning rate: 1e-4 | 学习率: 1e-4 |
Batch size: 32 | 批量大小: 32 |
Gradient accumulation steps: 4 | 梯度累积步数: 4 |
Image resolution: 768 | 图像分辨率: 768 |
Dataset size: 7k images | 数据集大小: 7 千张图像 |
Mixed precision: fp16 | 混合精度: fp16 |
我们能够产生接近原始模型产生的图像质量,参数减少了近 40%,下面的样本结果不言自明:
微调基础模型的代码可以在 这里 找到。
LoRA 训练
在蒸馏模型上进行 LoRA 训练的一个优点是训练更快。下面是我们在蒸馏模型上对一些抽象概念进行的第一个 LoRA 训练的一些图像。LoRA 训练的代码可以在 这里 找到。
结论
我们邀请开源社区帮助我们改进并实现这些蒸馏 SD 模型的更广泛采用。用户可以加入我们的 Discord 服务器,在那里我们将宣布这些模型的最新更新,发布更多的检查点和一些令人兴奋的新 LoRAs。如果你喜欢我们的工作,请在我们的 Github 上点一下 star。
英文原文: https://hf.co/blog/sd_distillation
原文作者: Yatharth Gupta
译者: innovation64
审校/排版: zhongdongy (阿东)
开源 SD-Small 和 SD-Tiny 知识蒸馏代码与权重的更多相关文章
- 知识蒸馏(Distillation)
蒸馏神经网络取名为蒸馏(Distill),其实是一个非常形象的过程. 我们把数据结构信息和数据本身当作一个混合物,分布信息通过概率分布被分离出来.首先,T值很大,相当于用很高的温度将关键的分布信息从原 ...
- Deeplearning知识蒸馏
Deeplearning知识蒸馏 merge paddleslim.dist.merge(teacher_program, student_program, data_name_map, place, ...
- 【论文考古】知识蒸馏 Distilling the Knowledge in a Neural Network
论文内容 G. Hinton, O. Vinyals, and J. Dean, "Distilling the Knowledge in a Neural Network." 2 ...
- 开源囧事4:你们这些卖代码的能不能留自己的QQ号?留我QQ号干嘛?
缘起于开源项目 从 2017 年开始,陆陆续续写了一些开源项目放到开源网站里,都是一些实战项目,给大家练练手.有基础整合的demo,有 Spring Boot 博客项目,有 Spring Boot 商 ...
- 【DKNN】Distilling the Knowledge in a Neural Network 第一次提出神经网络的知识蒸馏概念
原文链接 小样本学习与智能前沿 . 在这个公众号后台回复"DKNN",即可获得课件电子资源. 文章已经表明,对于将知识从整体模型或高度正则化的大型模型转换为较小的蒸馏模型,蒸馏非常 ...
- Android 存储到SD卡,获取SD的大小及可用空间
使用Sdcard注意事项: 1.权限问题: <uses-permission android:name="android.permission.WRIT ...
- eclipse中如何向开源中国(码云)上传代码
摘要 本文将介绍如何将本地的项目提交到开源中国上去,过程比较详细,实现起来很简单.由于自己也算是一个新手,所以没有做过多的解释,只是单纯的描述了该如何去做. 1.在开源中国上面新建一个空项目 到这 ...
- 分享一个开源的JavaScript统计图表库,40行代码实现专业统计图表
提升程序员工作效率的工具/技巧推荐系列 推荐一个功能强大的文件搜索工具SearchMyFiles 介绍一个好用的免费流程图和UML绘制软件-Diagram Designer 介绍Windows任务管理 ...
- 95)PHP,文件上传知识和代码
首先是知识总结: 上传: 从浏览器端传输的到服务器端. 请求时: 数据从浏览器端传输到服务器端. 可见: 上传,发生在浏览器向服务器发出请求过程中. 文件,对于浏览器来讲,就是表单中的一个特殊类型的数 ...
- wndows程序设计之书籍知识与代码摘录-封装一个类似printf的messagebox
//----------------------------------------- //本程序展示了如何实现MessageBoxPrintf函数 //本函数能像printf那样格式化输出 //摘录 ...
随机推荐
- exclude查询时出掉或排除某个条件的信息
exclude查询时出掉或排除某个条件的信息 print(Student.objects.all().exclude(nickname='A')
- < Python全景系列-7 > 提升Python编程效率:模块与包全面解读
欢迎来到我们的系列博客<Python全景系列>!在这个系列中,我们将带领你从Python的基础知识开始,一步步深入到高级话题,帮助你掌握这门强大而灵活的编程语法.无论你是编程新手,还是有一 ...
- 百度飞桨(PaddlePaddle) - PP-OCRv3 文字检测识别系统 预测部署简介与总览
百度飞桨(PaddlePaddle) - PP-OCRv3 文字检测识别系统 预测部署简介与总览 百度飞桨(PaddlePaddle) - PP-OCRv3 文字检测识别系统 Paddle Infer ...
- 基于AIGC的京东购物助手的技术方案设想
灵感来源 随着AIGC的爆火,ChatGPT,GPT-4的发布,我作为一个算法工作者,深感AI发展的迅猛.最近,OpenAI的插件和联网功能陆续向用户公开,我也在第一时间试用了这些最新的功能.在Ope ...
- Java并发(九)----线程join、interrupt
1.join 方法详解 1.1 为什么需要 join? 下面的代码执行,打印 r 是什么? static int r = 0; public static void main(String[] arg ...
- 使用脚本收发 protobuf 协议数据
问题背景 最近做了一个 ipv6 相关的功能,发现使用 getifaddrs 获取的本地 ipv6 地址有可能不是真实的网络 ipv6 地址: 例如上图中通过 getifaddrs 获得了多个本地 i ...
- 三分钟快速了解什么是MES系统
大家好,我是Edison. 近日我打算系统学习和整理一下MES/MOM系统相关的领域知识,从而构建我的业务域知识背景.万丈高楼平地起,我们先从快速了解什么是MES系统开始吧! 作为IT技术从业者,特别 ...
- CANoe学习笔记(四):UDS常用否定响应
UDS中定义的否定响应代码常用的: ServiceNotSupported/服务不支持($11 ) 当诊断仪发送的请求消息中服务标识符无法识别或不支持时,ECU应发送该响应码 SubFunctionN ...
- Hello Welcome to my blog!
Hello Welcome to my blog!
- 【HarmonyOS】一文教你如何在低代码项目中跳转H5页面
[关键字] 元服务.低代码.H5页面跳转.WebView [1.写在前面] 今天我们来实现一个在低代码项目中通过按钮跳转到H5页面的功能,本项目是基于API6的JS工程,我们的实现思路是在页面B中 ...