良好的权重初始化可以有效降低深度神经网络(DNN)模型的训练成本。如何初始化参数的选择是一个具有挑战性的任务,可能需要手动调整,这可能既耗时又容易出错。为了解决这些限制,论文迈出了建立权重生成器以合成神经网络初始化权重的创新一步。采用图像到图像的转换任务,使用生成对抗网络(GAN)作为示例,因为这方面的模型权重收集相对简单。

具体而言,首先收集了一个包含各种图像编辑概念及其对应训练权重的数据集,这些数据集随后用于权重生成器的训练。为了应对层之间的不同特性及需要预测的权重数量庞大,将权重划分为相等大小的块,并为每个块分配一个索引。随后,使用文本条件(即概念)和块索引的这种数据集来训练扩散模型。通过用扩散模型预测的去噪权重初始化图像转换模型,训练只需 43.3秒。与从头开始训练(即Pix2pix)相比,实现了新概念训练时间加速15\times的效果,同时获得了更好的图像生成质量。

来源:晓飞的算法工程笔记 公众号,转载请注明出处

论文: Efficient Training with Denoised Neural Weights

Introduction


高效训练深度神经网络(DNN)不仅加快了模型开发过程,还降低了对计算资源和成本的要求。许多之前的研究探讨了高效训练策略,如稀疏训练和低比特训练。然而,实现高效训练往往受到有效初始化模型权重的挑战所阻碍。虽然在权重初始化领域已经采取了一些措施,但在不同任务中确定合适的方案仍然具有挑战性。调节权重初始化的参数可能会耗时且容易出错,导致性能不佳和训练时间增加。

为了解决这些挑战,受到最近HyperNetworks设计进展的启发,论文首次研究了构建一个权重生成器的可行性,以在不同任务中提供更好的权重初始化,减少获得经过良好训练的DNN模型所需的训练时间和资源消耗。以使用GAN模型训练的图像到图像翻译任务为例,展开在预测神经权重方面的设计。需要注意,论文的框架是一个通用设计,并不局限于生成GAN权重,选择这个例子的原因在于可以轻松获取在不同数据集上训练的海量不同权重。

更具体地说,权重生成器可以为未见的新概念和风格预测初始化权重。为了减少需要预测的权重数量,将低秩适配(LoRA)应用于图像生成模型,从而在保持高质量图像生成的同时,大幅减少模型参数的数量。由于GAN模型由不同类型的层组成,且具有不同的权重大小和数量,对权重进行分组,并将其划分为大小相等的块。利用扩散过程来建模GAN模型的训练好的权重空间,通过训练扩散模型进行权重估计,即权重生成器。为了提高权重生成器的性能,进一步将块索引作为权重生成器中的一个条件机制,采用正弦位置编码方案,并计算块索引的嵌入。该嵌入为权重生成器提供关于每个权重块在所有模型权重中的位置的信息。在获得权重生成器后,为了训练一个基于GAN的图像翻译模型,通过单步去噪过程快速推断权重生成器,并使用预测的权重来初始化GAN模型。GAN模型只需随后进行高效的微调过程即可获得高质量的图像生成结果,显著减少获得新颖未见概念模型的时间消耗。

论文贡献总结如下:

  1. 提出了一个框架,用于生成不同概念/风格的权重初始化,以高效训练用于图像翻译的GAN模型。
  2. 在扩散模型的帮助下(即准备成对的图像数据集),收集了大量不同概念/风格的LoRA权重的真实数据集,这为权重生成器的训练奠定了基础。
  3. 通过利用扩散过程,引入了一种高效的权重生成器设计,该设计将文本概念信息和块索引作为输入。为了处理不同的层类型和权重形状,将权重组织为大小相等的一维块,显著减少了计算开销。通过将块索引与时间步(time step)嵌入相结合,这些块索引被无缝集成到权重生成器设计中。因此,权重生成器掌握了每个权重块在所有模型权重中的位置信息。
  4. 提出的框架可以通过单次去噪步骤预测GAN模型的初始化神经权重,仅需 \(1.19\) 秒。通过使用预测的权重进行初始化,快速微调过程可以在 \(42.1\) 秒内传达目标风格。与从头训练(即Pix2pix)相比,将总训练时间减少了 \(15\times\) ,同时保持了更好的图像生成质量。与其他高效训练方法相比,可以节省 \(4.6\times\) 的训练时间。

Motivations and Challenges


有效的权重初始化对稳定训练至关重要,能够促进更快的学习速率,加速收敛,并增强泛化能力。然而,在不同任务中确定良好的权重初始化仍然具有挑战性。受到最近超网络(HyperNetwork)进展的启发,论文希望调查是否可以构建一个权重生成器来获取良好的权重初始化,从而减少训练时间和资源消耗。与流行的图像/视频生成不同,探索权重生成的研究工作相对较少。构建这样的权重生成器前景广阔但也面临挑战。

第一个重大挑战来自深度神经网络(DNN)架构中的不同层类型。每一层的权重具有不同的大小和形状,这就需要一种能够适应这种异质性的权重生成方法。其次,权重生成器必须具备高效生成大量参数的能力,以确保网络的全面覆盖。第三,权重生成器的推理过程应快速有效,以节省为新任务获取权重的时间。

解决这些挑战有望构建出效率更高且有效性更强的深度学习系统的DNN训练模式。因此,在本研究中,论文研究了权重生成器的构建,以实现更好的权重初始化。论文旨在展示权重生成能力不仅限于在特定数据集上对单一模型架构的权重初始化,如基于 在CIFAR-10上的ResNet-18,而是适用于不同任务的多种模型。为了实现这一目标,以GANs在图像到图像转换任务中的初始化权重生成为例,因为收集多样化数据集用于GAN模型相对容易,但论文的方法并不局限于GAN架构或图像到图像转换任务。

Method


目标是训练一个权重生成器,以预测不同任务的权重初始化。以GANs在图像到图像转换任务中的应用为例,当出现新概念/风格时,可以查询权重生成器以提供初始化所需的权重值。权重生成器采用扩散过程建模,如图1所示。

与从纯噪声反转出干净图像的图像扩散模型不同,该框架旨在将噪声转化为用于初始化的权重值。通过插入预测的权重值,快速微调过程得以进行,以实现目标风格的GAN模型的高效训练。框架的核心是权重生成器的设计。

Dataset Collection

为了有效地训练一个权重生成器,用于生成不同概念的GAN模型的权重初始化,需要收集一个大规模的真实权重值数据集。为了获得真实权重值数据集,大规模的提示数据集显得尤为重要。通过使用提示数据集中的概念/风格,可以利用扩散模型进行图像收集,从而获得每个目标概念的代表性图像的丰富集合。每个概念/风格的图像进一步被利用来训练GAN,以获得真实的GAN权重。

作为权重生成器训练的数据准备基础,提示数据集应包括多样化的视觉概念/风格,以使权重生成器能够学习全面的表示,用于初始化针对特定任务的GAN。然而,收集这样一个数据集的过程面临巨大的挑战。确保不同概念/风格之间的多样性和代表性需要大量的数据。此外,收集到的提示还会进一步用于利用扩散模型生成目标概念/风格的图像。

为了构建提示数据集以训练一个可靠的权重生成器用于GAN权重初始化,采用了一种系统的方法,结合大语言模型 (LLMs) 进行风格生成和增强,以确保概念表现的丰富性和多样性。首先概述三个广泛的类别:1)艺术概念,2)特征概念,以及3)面部修改概念。在每个类别中,利用一个大语言模型(ChatGPT-3.5)来请求生成一系列包含各种概念的文本描述。通过过滤冗余的概念/风格,进一步通过查询另一个大语言模型(Vicuna)来实施增强方法,以提供具有相似含义但不同表现的概念/风格。为了进一步丰富提示数据集,还在不同类别之间排列和组合概念/风格。通过这个过程,能够构造一个大规模的提示数据集,不仅涵盖了多样化的概念领域,且捕捉了复杂的风格差异,为权重生成器的训练提供了更好的权重初始化基础。

在提示数据集收集之后,使用扩散模型编辑真实图像,以获得提示数据集中每个概念/风格的编辑图像,形成用于GAN训练的数据对。在这里,采用了一种混合了ResNet块和Transformer块的生成器(E2gan)作为训练模型,这种模型的有效性和混合架构设计能够展示不同类型层上的生成能力。在GAN训练过程后,为不同的概念/风格建立一个来自GAN检查点的权重数据集。为了进一步增强权重值数据集,在FID指标收敛后,为每个概念/风格保存 \(K\) 个检查点。

Data Format Design for Weight Generator

为了训练一个能够高效生成适用于不同概念的GAN模型权重初始化的权重生成器,设计训练和推断的权重格式是非常重要的。目标是每当新的概念作为输入提供给权重生成器时,它能够为该概念生成所有层的权重初始化。考虑到模型中存在多种不同类型的层,例如全连接层(FC)、卷积层(CONV)和批量归一化层(BN),以及不同层之间的大小和维度差异,设计合适的数据格式变得至关重要且具有挑战性。此外,GAN模型中的权重规模通常在百万级别,这为数据格式设计带来了更多挑战。

需要预测的权重数量越大,权重生成器面临的困难就越多。为了减轻这一问题,对不同层应用低秩适配(LoRA),以大幅减少需要预测的权重数量。例如,对于一个权重为 \(\mathbf{w}_i \in \mathbb{R}^{c\times f \times k_h\times k_w}\) 的卷积层 \(i\) ,应用两个秩为 \(r_i\) 的低秩矩阵,即 \(\mathbf{w}_{i}^A \in \mathbb{R}^{c\times r_i \times k_h \times k_w}\) 作为LoRA下层, \(\mathbf{w}_{i}^B \in \mathbb{R}^{r_i \times f \times 1\times1}\) 作为LoRA上层,以近似权重变化。通过这样做,待预测的权重总量从7.06M减少到0.22M。微调LoRA权重足以转移GAN模型的生成领域,尽管大大减少了权重数量,但通过权重生成器一次性直接预测所有0.22M权重仍然是具有挑战性的。这需要一个大的权重生成器,并且会带来巨大的计算和内存负担。

为了解决这个问题,将权重划分为多个组,以减轻计算复杂性,并增强在训练和推断期间将权重生成器适配到内存中的可行性。由于不同层具有不同的统计特性,将每个层 \(i\) 的LoRA下层和上层(如果适用,还包括相关的BN层)分为一个组。尽管如此,每个组的权重数量和形状仍然不同。因此,进一步将权重展平为一维向量,并将权重划分为 \(N\) 个等大小的块,每个块包含 \(b\) 个权重。

于是,数据格式表示为 \(<n, \mathbf{w}_n, T>\) ,其中 \(n\) 是块索引, \(\mathbf{w}_n \in \mathbb{R}^b\) 是第 \(n\) 个权重块的展平一维权重向量, \(T\) 表示当前概念/风格的文本提示。使用这种数据格式的优点包括:1)适用于不同类型和形状的层;2)降低了计算复杂性和预测难度;3)使权重生成器更容易适配到内存中。

Weight Generator Training

使用论文的权重值数据集训练一个生成模型,该模型学习为其他概念/风格提供权重初始化。通过扩散过程对GAN的权重初始化空间进行建模。生成器是一个UNet权重信息生成器 \(\hat{\mathbf{\epsilon}}_\theta\) ,其参数为 \(\theta\) ,用于一维向量,如图2所示。将权重块 \(\mathbf{w}_n\) 从真实权重分布 \(p(\mathbf{w}_n)\) 扩散(迭代)为一个噪声版本,并训练去噪UNet逐渐逆转这个过程,从高斯噪声中生成权重。训练可以形式化为以下噪声预测问题:

\[\begin{equation}
\min_\theta \mathbb{E}[\| \hat{\epsilon}_\theta(\mathbf{w}_n^t,t,n,\tau(T)) - \mathbf{\epsilon} \|_2^2],
\end{equation}
\]

其中 \(t\) 表示时间步; \(\epsilon\) 是真实噪声; \(\mathbf{w}_n^t = \alpha_t \mathbf{w}_n + \sigma_t \epsilon\) 是块 \(n\) 的噪声权重; \(\alpha_t\) 和 \(\sigma_t\) 分别是信号和噪声的强度,由噪声调度器决定; \(\tau\) 是一个冻结的文本编码器,如CLIP

为了将块索引作为权重生成器中的进一步条件机制,采用来自于序列到序列模型中常用的正弦位置编码。计算正弦块索引编码,该编码用于向权重生成器提供有关每个权重块在所有模型权重中的位置的信息。具体而言,令 \(N\) 表示权重块的总数, \(d\) 表示编码的维度。块索引 \(n\) 的正弦块索引编码 \(\text{SinEnc}(n, d)\) 计算如下:

\[\begin{equation}
\text{SinEnc}(n, 2i) = \sin\left(\frac{n}{10000^{2i/d}}\right), \text{SinEnc}(n, 2i + 1) = \cos\left(\frac{n}{10000^{2i/d}}\right),
\end{equation}
\]

其中 \(i\) 从0到 \(\left\lfloor\frac{d-1}{2}\right\rfloor\) 。将正弦编码输入到嵌入层中,以获得块索引嵌入 \(emb\_n\) ,将块索引嵌入 \(emb\_n\) 与时间步嵌入 \(emb\_t\) 结合,表示为 \(emb = emb\_n + emb\_t\) ,以便在生成器的每个残差块中使用。因此,权重生成器在整个去噪过程中都可以访问块索引。根据结果,论文观察到块索引 \(n\) 能够有效地建模来自不同块的权重,而不必依赖于先前预测的权重,从而大大减少了计算量。

Fast Fine-Tuning with Generated Weight Initializations

当出现一个新概念/风格 \(T\) 时,可以通过对每个权重块 \(n\) 进行已训练权重生成器 \(\hat{\epsilon}_\theta\) 的推理来获得权重初始化。为了快速获取权重初始化,采用直接重建方法以避免迭代去噪过程。更具体地说,选定的偏向噪声的时间步 \(t\) ,推理去噪扩散模型来预测噪声 \(\hat{\epsilon}_\theta(\mathbf{w}_n^t, t, n, \tau(T))\) ,并进行直接恢复以获得真实权重 \(\mathbf{w}_{n}=\mathbf{w}_{n}^0\) :

\[\begin{equation}
\mathbf{w}_{n}^0 = \frac{1}{\alpha_t} \mathbf{w}_{n}^t - \sigma_t \hat{\epsilon}_\theta(\mathbf{w}_n^t,t,n,\tau(T)).
\end{equation}
\]

在对所有 \(N\) 个权重块进行推理之后,可以获得概念/风格 \(T\) 的权重初始化 \(\{\mathbf{w}_{n} \}_{n=1}^N\) 。

为了更好地捕捉新概念/风格的细节,利用条件GAN损失对GAN权重进行进一步的微调,具体如下:

\[\begin{equation}
\begin{aligned}
&\min_{\mathbf{w}_{lora}} \max_{\mathbf{w}_d} \lambda \underbrace{ \mathbb{E}_{\mathbf{x},\tilde{\mathbf{x}}^T,\mathbf{z}, T} \left[ \| \tilde{\mathbf{x}}^T - \mathcal{G}(\mathbf{x}, \mathbf{z}, T;\mathbf{w}_g,\mathbf{w}_{lora}) \|_1 \right]}_{\textrm{$\ell_1$ loss}} + \\ &\underbrace{\mathbb{E}_{\mathbf{x}, \tilde{\mathbf{x}}^T} \left[\log \mathcal{D} (\mathbf{x},\tilde{\mathbf{x}}^T; \mathbf{w}_d) \right] + \mathbb{E}_{\mathbf{x},{\mathbf{z}}, T} \left[\log (1- \mathcal{D} (\mathbf{x}, \mathcal{G}(\mathbf{x},\mathbf{z}, {T}; \mathbf{w}_g); \mathbf{w}_d)) \right]}_{\textrm{conditional GAN loss}},
\end{aligned}
\end{equation}
\]

其中 \(\tilde{\mathbf{x}}^T\) 表示由扩散模型生成的基于目标风格的概念 \(T\) 的图像, \(\mathcal{G}\) 是具有原始权重 \(\mathbf{w}_g\) 和LoRA权重 \(\mathbf{w}_{lora}\) 的生成器, \(\mathcal{D}\) 表示由 \(\mathbf{w}_d\) 参数化的判别器函数, \(\mathbf{z}\) 是引入的随机噪声,以增加输出的随机性, \(\lambda\) 可用于调整两个损失项之间的相对重要性。

在微调过程中,生成器仅优化使用预测 \(\{\mathbf{w}_{n} \}_{n=1}^N\) 初始化的LoRA权重 \(\mathbf{w}_{lora}\) 。通过从预测中初始化GAN权重,能够使用更少的训练周期达到相同或更好的FID性能。除了在预测后进行微调外,还考虑将公式4中的GAN训练损失纳入公式1中的权重预测损失。然而,通过实验,论文发现将这两个损失项结合并不能提供更好的性能,反而增加了训练权重生成器的计算成本。

Experiments






如果本文对你有帮助,麻烦点个赞或在看呗~

更多内容请关注 微信公众号【晓飞的算法工程笔记】

新思路,基于Diffusion的初始化权重生成策略 | ECCV'24的更多相关文章

  1. Hibernate入门之主键生成策略详解

    前言 上一节我们讲解了Hibernate命名策略,从本节我们开始陆续讲解属性.关系等映射,本节我们来讲讲主键的生成策略. 主键生成策略 JPA规范支持4种不同的主键生成策略(AUTO.IDENTITY ...

  2. 高并发环境下全局id生成策略

    解决方案: 基于Redis的全局id生成策略:(推荐此方法) 基于雪花算法的全局id生成: https://www.cnblogs.com/kobe-qi/p/8761690.html 基于zooke ...

  3. 基于按annotation的hibernate主键生成策略

    基于按annotation的hibernate主键生成策略 博客分类: Hibernate HibernateJavaJPAOracleMySQL  这里讨论代理主键,业务主键(比如说复合键等)这里不 ...

  4. jpa基于按annotation的hibernate主键生成策略

    JPA注解持久化类很方便,需要jar包:ejb3-persistence.jar下载 一.JPA通用策略生成器 通过annotation来映射hibernate实体的,基于annotation的hib ...

  5. 图解Janusgraph系列-分布式id生成策略分析

    JanusGraph - 分布式id的生成策略 大家好,我是洋仔,JanusGraph图解系列文章,实时更新~ 本次更新时间:2020-9-1 文章为作者跟踪源码和查看官方文档整理,如有任何问题,请联 ...

  6. (二)JPA 连接工厂、主键生成策略、DDL自动更新

    (一)JPA的快速入门 2.JPA连接工厂 通过之前的 代码 实现已经清楚的发现了整个的JPA实现步骤,但是这个步骤似乎有一些繁琐了,毕竟最终所关心的一定是EntityManager对象实例,而要想获 ...

  7. Hibernate(4)——主键生成策略、CRUD 基础API区别的总结 和 注解的使用

    俗话说,自己写的代码,6个月后也是别人的代码……复习!复习!复习!涉及的知识点总结如下: hibernate的主键生成策略 UUID 配置的补充:hbm2ddl.auto属性用法 注解还是配置文件 h ...

  8. hibernate主键生成策略

    在hibernate中,提供了多种主键生成器(不同的数据库,不同的表结构使用的主键生成策略也不相同),查阅相关资料经过实验总结如下: 1.increment 主键按照数值顺序递增,使用当前实例中最大值 ...

  9. Hibernate主键生成策略(转)

    1.自动增长identity 适用于MySQL.DB2.MS SQL  Server,采用数据库生成的主键,用于为long.short.int类型生成唯一标识 使用SQL Server 和 MySQL ...

  10. hibernate 注解 主键生成策略

    一.JPA通用策略生成器       通过annotation来映射hibernate实体的,基于annotation的hibernate主键标识为@Id, 其生成规则由@GeneratedValue ...

随机推荐

  1. 洛谷P1226 【模板】快速幂

    1.快速幂模板 前置知识 一个数字n,它的二进制位数一定是log2n向下取整+1: 快速幂模板代码 这段代码实现了快速幂算法(Exponentiation by squaring),用来计算 ( an ...

  2. 基于druid和spring的动态数据库以及读写分离 转

    spring与druid可以实现动态数据源,夸库查询,读写分离等功能.现在说一下配置: 1.需要配置多个spring数据源 spring-data.xml <!-- 动态数据源 --> & ...

  3. md2pdf

    https://www.pandoc.org/installing.html https://github.com/jgm/pandoc/releases/tag/2.18 https://blog. ...

  4. [APIO2019] 路灯 题解

    LG5445 把询问 \(x,y\) 看作平面上的点 记当前时刻 \(t\),\(l\) 是与 \(i\) 连通的最左端,\(r\) 是与 \(i+1\) 连通的最右端,可以通过 set 维护断边找到 ...

  5. C程序起点main函数

    C程序起点main函数 main c语言中main函数接收两个参数int argc, char* argv[] int main(int argc, char* argv[]); int main(i ...

  6. CF1693D--单调区间

    \(T_4\) 单调区间结题报告 题目描述 一句话题意:给定一个排列 \(a\) 算出有多少个区间 \([l , r]\) , 满足其可以划分为一个单调递增子序列和单调递减子序列,其中单调递增子序列长 ...

  7. Spring Boot 框架中配置文件 application.properties 当中的所有配置大全

    Spring Boot 框架中配置文件 application.properties 当中的所有配置大全 #SPRING CONFIG(ConfigFileApplicationListener) s ...

  8. Python 项目及依赖管理工具技术选型

    Python 项目及依赖管理工具,类似于 Java 中的 Maven 与 Node 中的 npm + webpack,在开发和维护项目时起着重要的作用.使用适当的依赖管理工具可以显著提高开发效率,减少 ...

  9. k8s pod挂载hostPath执行写时报错Permission denied

    关于hostPath的权限说明 最近项目中经常遇到pod中container挂载主机hostPath报错无权限问题: httpd@hostpath-volume:/test-volume$ touch ...

  10. 【YashanDB知识库】23.1.3.101版本创建物化视图coredump

    [标题]23.1.3.101版本创建物化视图coredump [问题分类]数据库错误 [关键词]YashanDB, 物化视图, coredump, dblink [问题描述]在23.1.3.101版本 ...