Song Y. and Ermon S. Generative modeling by estimating gradients of the data distribution. In Advances in Neural Information Processing Systems (NIPS), 2019.

当前生成模型, 要么依赖对抗损失(GAN), 要么依赖替代损失(VAE), 本文提出了基于score matching 训练, 以及利用annealed Langevin dynamics推断的模型, 思想非常有趣.

主要内容

Langevin dynamics

对于分布\(p(x)\), 我们可以通过下列方式迭代生成

\[\tilde{x}_t = \tilde{x}_{t-1} + \frac{\epsilon}{2} \nabla_x \log p (\tilde{x}_{t-1}) + \sqrt{\epsilon} z_t,
\]

其中\(\tilde{x}_0 \sim \pi(x)\)来自一个先验分布, \(z_t \sim \mathcal{N}(0, I)\). 当步长\(\epsilon \rightarrow 0\)并且\(T \rightarrow +\infty\)的时候, \(\tilde{x}_T\)可以认为是从\(p(x)\)中采样的样本.

注: 一般的Langevin, dynamics还需要在每一次迭代后计算一个接受概率然后判断是否接受, 不过在实际中这一步往往可以省略.

Score Matching

通过上述的迭代可以发现, 我们只需要获得\(\nabla_x \log p(x)\)即可采样\(x\), 我们可以期望通过下面的方式, 通过一个网络\(s_{\theta}(x)\)来逼近\(\nabla_x \log p_{data}(x)\):

\[\min_{\theta} \: \frac{1}{2} \mathbb{E}_{p_{data}(x)} [\| s_{\theta} (x) - \nabla_x \log p_{data}(x) \|_2^2],
\]

但是在实际中, 先验\(\log p_{data}(x)\)也是未知的, 幸运的是上述公式等价于:

\[\min_{\theta} \: \mathbb{E}_{p_{data}(x)} [\mathrm{tr}(\nabla_x s_{\theta} (x)) + \frac{1}{2} \|s_{\theta}(x)\|_2^2].
\]

注: 见 score matching

Denoising Score Matching

一个共识是, 所获得的数据往往是一个低维流形, 即其内在的维度实际上很低. 所以\(\mathbb{E}_{p_{data}(x)}\)在实际中会出现高密度的区域估计得很好, 但是低密度得区域估计得非常差. Denosing Score Matching提高了一个较为鲁棒的替代方法:

\[\min_{\theta} \: \frac{1}{2} \mathbb{E}_{q_{\sigma}(\tilde{x}|x)p_{data}(x)} [\| s_{\theta} (\tilde{x}) - \nabla_x \log q_{\sigma}(\tilde{x}|x) \|_2^2].
\]

当优化得足够好的时候,

\[s_{\theta^*}(x) = \nabla_x \log q_{\sigma}(x), \: q_{\sigma}(\tilde{x}) := \int q_{\sigma}(\tilde{x}|x) p_{data}(x) \mathrm{d}x.
\]

实际中, 通常取\(q_{\sigma}(\tilde{x}|x) = \mathcal{N}(\tilde{x}|x, \sigma^2 I)\), 相当于在真实数据\(x\)上加了一个扰动, 当扰动足够小(\(\sigma\)足够小)的时候, \(q_{\sigma}(x) \approx p_{data}(x)\), 则\(s_{\theta^*}(x) \approx \nabla_x \log p_{data}(x)\).

注: 为啥期望部分要有\(p_{data}\)? 实际上上述目标和score matching依旧是等价的.

Noise Conditional Score Networks

Slow mixing of Langevin dynamics

假设\(p_{data}(x) = \pi p_1(x) + (1 - \pi)p_2(x)\), 且\(p_1, p_2\)的支撑集合是互斥的, 那么 \(\nabla_{x} \log p_{data}(x)\)要么为\(\nabla_{x} \log p_{1}(x)\)或者\(\nabla_{x} \log p_{2}(x)\), 与\(\pi\)没有丝毫关联, 这会导致训练的结果与\(\pi\)也没有关联. 在实际中, 若\(p_1, p_2\)近似互斥, 也会产生类似的情况:

如上图所示, 通过Langevin dynamics采样的点几乎是1:1的, 这与真实的分布便有了出入.

作者的想法是, 设计一个noise conditional score networks:

\[s_\theta(x, \sigma),
\]

给定不同的\(\sigma\)其拟合不同扰动大小的\(p_{\sigma}\), 在采样中, 首先用大一点的\(\sigma\), 然后再逐步缩小, 这便是一种退火的思想. 显然, 一开始用大一点的\(\sigma\)能够为后面的采样提供更好更鲁棒的初始点.

损失函数

设定\(\sigma_i, i=1,2,\cdots, L\), 且满足:

\[\frac{\sigma_1}{\sigma_2} = \cdots = \frac{\sigma_{L-1}}{\sigma_L} > 1,
\]

即一个等比例(缩小)的数列.

对于每个\(\sigma\)采用如下损失:

\[\ell(\theta; \sigma) =
\frac{1}{2} \mathbb{E}_{p_{data}(x)} \mathbb{E}_{\mathcal{N}(\tilde{x}|x, \sigma I)} [\| s_{\theta} (\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2} \|_2^2].
\]

注: \(\nabla_{\tilde{x}} q_{\sigma}(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}\).

于是总损失为

\[\mathcal{L}(\theta; \{\sigma_i\}_{i=1}^L) := \frac{1}{L}\sum_{i=1}^L \lambda (\sigma_i)\ell(\theta;\sigma_i),
\]

\(\lambda(\sigma_i)\)为权重系数.

Annealed Langevin dynamics

Input: \(\{\sigma_i\}_{i=1}^L, \epsilon, T\);

  1. 初始化\(x_0\);
  2. For \(i=1,2,\cdots, L\) do:
    • \(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\);
    • For \(t=1,2,\cdots, T\) do:
      • 采样\(z_t \sim \mathcal{N}(0, I)\);
      • \(x_t \leftarrow x_{t-1} + \frac{\alpha_i}{2}s_{\theta}(x_{t-1}, \sigma) + \sqrt{\alpha_i} z_t\);
    • \(x_0 \leftarrow x_T\);

Output: \(x_T\).

细节

  1. 关于参数\(\lambda(\sigma)\)的选择:

    作者推荐选择\(\lambda(\sigma) = \sigma^2\), 因为当优化到最优的时候, \(\|s_{\theta}(x, \sigma)\|_2 \propto 1 / \sigma\), 故\(\sigma^2 \ell(\theta;\sigma) = \frac{1}{2}\mathbb{E}[\|\sigma s_{\theta}(x, \sigma) + \frac{\tilde{x} - x}{\sigma} \|_2^2]\), 其中\(\sigma s_{\theta}(x, \sigma) \propto 1, \frac{\tilde{x} - x}{\sigma} \sim \mathcal{N}(0, I)\), 故\(\sigma^2 \ell_{\theta,\sigma}\)与\(\sigma\)无关.

  2. 关于\(\alpha_i \leftarrow \epsilon \cdot \sigma_i^2 / \sigma_L^2\):

对于一次Langevin dynamic, 其获得的信息为: \(\frac{\alpha_i}{2} s_{\theta}(x_{t-1}, \sigma)\), 其噪声为\(\sqrt{\alpha_i}z_t\), 故其信噪比(signal-to-noise)为(应该是element-wise的计算?)

\[\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z},
\]

当我们按照算法中的取法时, 我们有

\[\begin{array}{ll}
\|\frac{\alpha_i s_{\theta}(x, \sigma_i)}{2 \sqrt{\alpha_i} z}\|_2^2
&\approx\frac{\alpha_i \| s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto\frac{\|\sigma_i s_{\theta}(x, \sigma_i)\|_2^2}{4} \\
&\propto \frac{1}{4}.
\end{array}
\]

故采用此策略能够保证SNR保持一个稳定的值.

代码

原文代码

Generative Modeling by Estimating Gradients of the Data Distribution的更多相关文章

  1. 泡泡一分钟:GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping

    张宁  GEN-SLAM - Generative Modeling for Monocular Simultaneous Localization and Mapping GEN-SLAM  - 单 ...

  2. Statistics : Data Distribution

    1.Normal distribution In probability theory, the normal (or Gaussian or Gauss or Laplace–Gauss) dist ...

  3. 论文笔记之:Generative Adversarial Nets

    Generative Adversarial Nets NIPS 2014  摘要:本文通过对抗过程,提出了一种新的框架来预测产生式模型,我们同时训练两个模型:一个产生式模型 G,该模型可以抓住数据分 ...

  4. [GAN] Generative networks

    中文版:https://zhuanlan.zhihu.com/p/27440393 原文版:https://www.oreilly.com/learning/generative-adversaria ...

  5. Generative model 和Discriminative model

    学习音乐自动标注过程中设计了有关分类型模型和生成型模型的东西,特地查了相关资料,在这里汇总. http://blog.sina.com.cn/s/blog_a18c98e50101058u.html ...

  6. 生成模型(Generative)和判别模型(Discriminative)

    生成模型(Generative)和判别模型(Discriminative) 引言    最近看文章<A survey of appearance models in visual object ...

  7. (转)Deep Learning Research Review Week 1: Generative Adversarial Nets

    Adit Deshpande CS Undergrad at UCLA ('19) Blog About Resume Deep Learning Research Review Week 1: Ge ...

  8. (转)Introductory guide to Generative Adversarial Networks (GANs) and their promise!

    Introductory guide to Generative Adversarial Networks (GANs) and their promise! Introduction Neural ...

  9. SalGAN: Visual saliency prediction with generative adversarial networks

    SalGAN: Visual saliency prediction with generative adversarial networks 2017-03-17 摘要:本文引入了对抗网络的对抗训练 ...

随机推荐

  1. 学习java的第八天

    一.今日收获 1.学习完全学习手册上2.3转义字符与2.4运算符两节 二.今日难题 1.没有什么难理解的问题 三.明日目标 1.哔哩哔哩教学视频 2.Java学习手册

  2. MapReduce03 框架原理InputFormat数据输入

    目录 1 InputFormat数据输入 1.1 切片与MapTask并行度决定机制 问题引出 MapTask并行度决定机制 Job提交流程源码 切片源码 1.2 FileInputFormat切片机 ...

  3. acre, across

    acre The acre is a unit of land area used in the imperial and US customary systems. It is traditiona ...

  4. Hbase(二)【shell操作】

    目录 一.基础操作 1.进入shell命令行 2.帮助查看命令 二.命名空间操作 1.创建namespace 2.查看namespace 3.删除命名空间 三.表操作 1.查看所有表 2.创建表 3. ...

  5. Linux学习 - 流程控制

    一.if语句 1 单分支if条件语句 (1) if  [ 条件判断式 ];then 程序  fi (2) if [ 条件判断式 ] then 程序  fi 例:检测根分区的使用量 2 双分支if条件语 ...

  6. OC简单介绍

    一.OC与C的对比 关键字 OC新增的关键字在使用时,注意部分关键字以"@"开头 方法->函数 定义与实现 数据类型 新增:BOOL/NSObject/id/SEL/bloc ...

  7. 使用Spring JDBC连接数据库(以SQL Server为例)

    一.配置Spring JDBC 1.导入相关jar包 (略) 2.配置文件applicationContext.xml <?xml version="1.0" encodin ...

  8. HTTPS及流程简析

    [序] 在我们在浏览某些网站的时候,有时候浏览器提示需要安装根证书,可是为什么浏览器会提示呢?估计一部分人想也没想就直接安装了,不求甚解不好吗? 那么什么是根证书呢?在大概的囫囵吞枣式的百度之后知道了 ...

  9. numpy基础教程--where函数的使用

    在numpy中,where函数是一个三元运算符,函数原型为where(condition, x, y),意思是当条件成立的时候,将矩阵的值设置为x,否则设置为y 一个很简单的应用就是,在一个矩阵当中, ...

  10. 3、Spring的DI依赖注入

    一.DI介绍 1.DI介绍 依赖注入,应用程序运行依赖的资源由Spring为其提供,资源进入应用程序的方式称为注入. Spring容器管理容器中Bean之间的依赖关系,Spring使用一种被称为&qu ...