Diffusion系列 - DDIM 公式推导 + 代码 -(三)
DENOISING DIFFUSION IMPLICIT MODELS (DDIM)
从DDPM中我们知道,其扩散过程(前向过程、或加噪过程)被定义为一个马尔可夫过程,其去噪过程(也有叫逆向过程)也是一个马尔可夫过程。对马尔可夫假设的依赖,导致重建每一步都需要依赖上一步的状态,所以推理需要较多的步长。
q(x_t|x_{0}) := \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_{0},{(1-\bar{\alpha}_t})I)
\]
q(x_{t-1}|x_t,x_0)
&\overset{Bayes}{=} \dfrac{q(x_t|x_{t-1},x_0)q(x_{t-1}|x_0)}{q(x_t|x_0)} \\
&\overset{Markov}{=} \dfrac{q(x_t|x_{t-1})q(x_{t-1}|x_0)}{q(x_t|x_0)}
\end{align*}
\]
DDPM中对于其逆向分布的建模使用马尔可夫假设,这样做的目的是将式子中的未知项 \(q(x_t|x_{t-1},x_0)\),转化成了已知项 \(q(x_t|x_{t-1})\),最后求出 \(q(x_{t-1}|x_t,x_0)\) 的分布也是一个高斯分布 \(\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\)。
从DDPM的结论出发,我们不妨直接假设 \(q(x_{t-1}|x_t,x_0)\) 的分布为高斯分布,在不使用马尔可夫假设的情况下,尝试求解 \(q(x_{t-1}|x_t,x_0)\) 。
由 DDPM 中 \(q(x_{t-1}|x_t,x_0)\) 的分布 \(\mathcal{N}(x_{t-1};\mu_q(x_t,x_0),\Sigma_q(t))\) 可知,均值为 一个关于 \(x_t,x_0\) 的函数,方差为一个关于 \(t\) 的函数。
我们可以把 \(q(x_{t-1}|x_t,x_0)\) 设计成如下分布:
\]
这样,只要求解出 \(a,b,\sigma_t\) 这三个待定系数,即可确定 \(q(x_{t-1}|x_t,x_0)\) 的分布。
重参数化 \(q(x_{t-1}|x_t,x_0)\) :
\]
假设训练模型时输入噪声图片的加噪参数与DDPM完全一致
由 \(q(x_t|x_{0}) := \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_{0},(1-\bar{\alpha}_t)I)\) :
\]
代入 \(x_t\) 有:
x_{t-1} &=a x_0 + b(\sqrt{\bar{\alpha}_t}x_{0}+\sqrt{1-\bar{\alpha}_t}\varepsilon^{\prime}_{t}) + \sigma_t \varepsilon^{\prime}_{t-1} \\
&= (a + b\sqrt{\bar{\alpha}_t}) x_0 + (b\sqrt{1-\bar{\alpha}_t}\varepsilon^{\prime}_{t} + \sigma_t \varepsilon^{\prime}_{t-1}) \\
&= (a + b\sqrt{\bar{\alpha}_t}) x_0 + (\sqrt{b^2(1-\bar{\alpha}_t)+ \sigma_t^2}) \bar{\varepsilon}_{t-1}
\end{align*}
\]
又:
\]
观察系数可以得到方程组:
a + b\sqrt{\bar{\alpha}_t} = \sqrt{\bar{\alpha}_{t-1}} \\
\sqrt{b^2(1-\bar{\alpha}_t)+ \sigma_t^2} = \sqrt{1-\bar{\alpha}_{t-1}}
\end{cases}
\]
三个未知数 两个方程,可以用 \(\sigma_t\) 表示 \(a,b\):
a = \sqrt{\bar{\alpha}_{t-1}} - \sqrt{\bar{\alpha}_t} \sqrt{\dfrac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_t}} \\
b = \sqrt{\dfrac{1-\bar{\alpha}_{t-1}-\sigma_t^2}{1-\bar{\alpha}_t}}
\end{cases}
\]
\(a, b\) 代入 \(q(x_{t-1}|x_t,x_0) := \mathcal{N}(x_{t-1}; a x_0 + b x_t,\sigma_t^2 I)\)
\]
又
x_0 = \dfrac{1}{\sqrt{\bar{\alpha}_t}}x_t - \dfrac{\sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}} \bar{\varepsilon}_0 \\
\]
代入 \(x_0\) 有:
\]
x_{t-1} &= \mu_q(x_t,x_0,t) + \sigma_t \varepsilon_0 \\
&= \sqrt{\bar{\alpha}_{t-1}} \underbrace{\dfrac{x_t-\sqrt{1-\bar{\alpha}_t} \bar{\varepsilon}_0}{\sqrt{\bar{\alpha}_{t}}}}_{预测的x_0} + \underbrace{\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} \bar{\varepsilon}_0}_{x_t的方向} + \underbrace{\sigma_t \varepsilon_0}_{随机噪声扰动}
\end{align*}
\]
通过观察 \(x_{t-1}\) 的分布,我们建模采样分布为高斯分布:
\]
并且均值和方差也采用相似的形式:
\mu_\theta(x_t,t) &= \sqrt{\bar{\alpha}_{t-1}} \dfrac{x_t-\sqrt{1-\bar{\alpha}_t} \epsilon_\theta(x_t,t) }{\sqrt{\bar{\alpha}_{t}}} + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} \epsilon_\theta(x_t,t) \\
\Sigma_\theta(x_t,t) &= \sigma_t^2
\end{align*}
\]
其中 \(\epsilon_\theta(x_t,t)\) 为预测的噪声。
此时,确定优化目标只需要 \(q(x_{t-1}|x_t,x_0)\) 和 \(p_\theta(x_{t-1}|x_t)\) 两个分布尽可能相似,使用KL散度来度量,则有:
&\quad \ \underset{\theta}{argmin} D_{KL}(q(x_{t-1}|x_t,x_0)||p_\theta(x_{t-1}|x_t)) \\
&=\underset{\theta}{argmin} D_{KL}(\mathcal{N}(x_{t-1};\mu_q, \Sigma_q(t))||\mathcal{N}(x_{t-1};\mu_\theta, \Sigma_q(t))) \\
&=\underset{\theta}{argmin} \dfrac{1}{2} \left[ log\dfrac{|\Sigma_q(t)|}{|\Sigma_q(t)|} - k + tr(\Sigma_q(t)^{-1}\Sigma_q(t)) + (\mu_q-\mu_\theta)^T \Sigma_q(t)^{-1} (\mu_q-\mu_\theta) \right] \\
&=\underset{\theta}{argmin} \dfrac{1}{2} \left[ 0 - k + k + (\mu_q-\mu_\theta)^T (\sigma_t^2I)^{-1} (\mu_q-\mu_\theta) \right] \\
&\overset{内积公式A^TA}{=} \underset{\theta}{argmin} \dfrac{1}{2\sigma_t^2} \left[ ||\mu_q-\mu_\theta||_2^2 \right] \\
&\overset{代入\mu_q,\mu_\theta}{=} \underset{\theta}{argmin} \dfrac{1}{2\sigma_t^2} (\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2} - \dfrac{\sqrt{\bar{\alpha}_{t-1}} \sqrt{1-\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}}) \left[ ||\bar{\varepsilon}_0-\epsilon_\theta(x_t,t)||_2^2 \right]
\end{align*}
\]
恰好与DDPM的优化目标一致,所以我们可以直接复用DDPM训练好的模型。
\(p_{\theta}\) 的采样步骤则为:
\]
令 \(\sigma_t=\eta \sqrt{\dfrac{(1-{\alpha}_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}}\)
当 \(\eta =1\) 时,前向过程为 Markovian ,采样过程变为 DDPM 。
当 \(\eta =0\) 时,采样过程为确定过程,此时的模型 称为 隐概率模型(implicit probabilstic model)。
DDIM如何加速采样:
在 DDPM 中,基于马尔可夫链 \(t\) 与 \(t-1\) 是相邻关系,例如 \(t=100\) 则 \(t-1=99\);
在 DDIM 中,\(t\) 与 \(t-1\) 只表示前后关系,例如 \(t=100\) 时,\(t-1\) 可以是 90 也可以是 80、70,只需保证 \(t-1 < t\) 即可。
此时构建的采样子序列 \(\tau=[\tau_i,\tau_{i-1},\cdots,\tau_{1}] \ll [t,t-1,\cdots,1]\) 。
例如,原序列 \(\Tau=[100,99,98,\cdots,1]\),采样子序列为 \(\tau=[100,90,80,\cdots,1]\) 。
DDIM 采样公式为:
\]
当 \(\eta= 0\) 时,DDIM 采样公式为:
\]
代码实现
训练过程与 DDPM 一致,代码参考上一篇文章。采样代码如下:
device = 'cuda'
torch.cuda.empty_cache()
model = Unet().to(device)
model.load_state_dict(torch.load('ddpm_T1000_l2_epochs_300.pth'))
model.eval()
image_size=96
epochs = 500
batch_size = 128
T=1000
betas = torch.linspace(0.0001, 0.02, T).to('cuda') # torch.Size([1000])
# 每隔20采样一次
tau_index = list(reversed(range(0, T, 20))) #[980, 960, ..., 20, 0]
eta = 0.003
# train
alphas = 1 - betas # 0.9999 -> 0.98
alphas_cumprod = torch.cumprod(alphas, axis=0) # 0.9999 -> 0.0000
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1-alphas_cumprod)
def get_val_by_index(val, t, x_shape):
batch_t = t.shape[0]
out = val.gather(-1, t)
return out.reshape(batch_t, *((1,) * (len(x_shape) - 1))) # torch.Size([batch_t, 1, 1, 1])
def p_sample_ddim(model):
def step_denoise(model, x_tau_i, tau_i, tau_i_1):
sqrt_alphas_bar_tau_i = get_val_by_index(sqrt_alphas_cumprod, tau_i, x_tau_i.shape)
sqrt_alphas_bar_tau_i_1 = get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)
denoise = model(x_tau_i, tau_i)
if eta == 0:
sqrt_1_minus_alphas_bar_tau_i = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape)
sqrt_1_minus_alphas_bar_tau_i_1 = get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i_1, x_tau_i.shape)
x_tau_i_1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * x_tau_i \
+ (sqrt_1_minus_alphas_bar_tau_i_1 - sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * sqrt_1_minus_alphas_bar_tau_i) \
* denoise
return x_tau_i_1
sigma = eta * torch.sqrt((1-get_val_by_index(alphas, tau_i, x_tau_i.shape)) * \
(1-get_val_by_index(sqrt_alphas_cumprod, tau_i_1, x_tau_i.shape)) / get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape))
noise_z = torch.randn_like(x_tau_i, device=x_tau_i.device)
# 整个式子由三部分组成
c1 = sqrt_alphas_bar_tau_i_1 / sqrt_alphas_bar_tau_i * (x_tau_i - get_val_by_index(sqrt_one_minus_alphas_cumprod, tau_i, x_tau_i.shape) * denoise)
c2 = torch.sqrt(1 - get_val_by_index(alphas_cumprod, tau_i_1, x_tau_i.shape) - sigma) * denoise
c3 = sigma * noise_z
x_tau_i_1 = c1 + c2 + c3
return x_tau_i_1
img_pred = torch.randn((4, 3, image_size, image_size), device=device)
for k in range(0, len(tau_index)):
# print(tau_index)
# 因为 tau_index 是倒序的,tau_i = k, tau_i_1 = k+1,这里不能弄反
tau_i_1 = torch.tensor([tau_index[k+1]], device=device, dtype=torch.long)
tau_i = torch.tensor([tau_index[k]], device=device, dtype=torch.long)
img_pred = step_denoise(model, img_pred, tau_i, tau_i_1)
torch.cuda.empty_cache()
if tau_index[k+1] == 0: return img_pred
return img_pred
with torch.no_grad():
img = p_sample_ddim(model)
img = torch.clamp(img, -1.0, 1.0)
show_img_batch(img.detach().cpu())
DDIM
https://arxiv.org/pdf/2010.02502
https://github.com/ermongroup/ddim
Diffusion系列 - DDIM 公式推导 + 代码 -(三)的更多相关文章
- Android系列之网络(三)----使用HttpClient发送HTTP请求(分别通过GET和POST方法发送数据)
[声明] 欢迎转载,但请保留文章原始出处→_→ 生命壹号:http://www.cnblogs.com/smyhvae/ 文章来源:http://www.cnblogs.com/smyhvae/p/ ...
- Android系列之Fragment(三)----Fragment和Activity之间的通信(含接口回调)
[声明] 欢迎转载,但请保留文章原始出处→_→ 生命壹号:http://www.cnblogs.com/smyhvae/ 文章来源:http://www.cnblogs.com/smyhvae/p/ ...
- ReactiveSwift源码解析(九) SignalProducerProtocol延展中的Start、Lift系列方法的代码实现
上篇博客我们聊完SignalProducer结构体的基本实现后,我们接下来就聊一下SignalProducerProtocol延展中的start和lift系列方法.SignalProducer结构体的 ...
- JavaScript 系列博客(三)
JavaScript 系列博客(三) 前言 本篇介绍 JavaScript 中的函数知识. 函数的三种声明方法 function 命令 可以类比为 python 中的 def 关键词. functio ...
- 【原创 深度学习与TensorFlow 动手实践系列 - 3】第三课:卷积神经网络 - 基础篇
[原创 深度学习与TensorFlow 动手实践系列 - 3]第三课:卷积神经网络 - 基础篇 提纲: 1. 链式反向梯度传到 2. 卷积神经网络 - 卷积层 3. 卷积神经网络 - 功能层 4. 实 ...
- Linux Shell系列教程之(三)Shell变量
本文是Linux Shell系列教程的第(三)篇,更多shell教程请看:Linux Shell系列教程 Shell作为一种高级的脚本类语言,也是支持自定义变量的.今天就为大家介绍下Shell中的变量 ...
- 【HANA系列】【第三篇】SAP HANA XS的JavaScript安全事项
公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[HANA系列][第三篇]SAP HANA XS ...
- 孟老板 ListAdapter封装, 告别Adapter代码 (三)
BaseAdapter系列 ListAdapter封装, 告别Adapter代码 (一) ListAdapter封装, 告别Adapter代码 (二) ListAdapter封装, 告别Adapter ...
- 《手把手教你》系列技巧篇(三十)-java+ selenium自动化测试- Actions的相关操作下篇(详解教程)
1.简介 本文主要介绍两个在测试过程中可能会用到的功能:Actions类中的拖拽操作和Actions类中的划取字段操作.例如:需要在一堆log字符中随机划取一段文字,然后右键选择摘取功能. 2.拖拽操 ...
- 《手把手教你》系列技巧篇(三十一)-java+ selenium自动化测试- Actions的相关操作-番外篇(详解教程)
1.简介 上一篇中,宏哥说的宏哥在最后提到网站的反爬虫机制,那么宏哥在自己本地做一个网页,没有那个反爬虫的机制,谷歌浏览器是不是就可以验证成功了,宏哥就想验证一下自己想法,于是写了这一篇文章,另外也是 ...
随机推荐
- 文件系统(十一):Linux Squashfs只读文件系统介绍
liwen01 2024.07.21 前言 嵌入式Linux系统中,squashfs文件系统使用非常广泛.它主要的特性是只读,文件压缩比例高.对于flash空间紧张的系统,可以将一些不需要修改的资源打 ...
- 【Java】SPI机制
SPI全称: 服务供应商接口 Service Provider Interface 服务发现机制 入门概念视频来自于: https://www.bilibili.com/video/BV1E44y1N ...
- 【JDBC】Extra02 SqlServer-JDBC
官网驱动获取地址: https://www.microsoft.com/zh-cn/download/details.aspx Maven仓库获取: https://mvnrepository.com ...
- 目前国内全地形能力最强的双足机器人 —— 逐际动力 —— 提出迭代式预训练(Iterative Pre-training)方法的强化学习算法
相关: https://weibo.com/1255595687/O5k4Aj8l2 该公司对其产品的强化学习训练算法给出了较少的描述: 提出迭代式预训练(Iterative Pre-training ...
- 如何使用二阶优化算法实现对神经网络的优化 —— 分布式计算的近似二阶优化算法实现对神经网络的优化 —— 《Distributed Hessian-Free Optimization for Deep Neural Network》
论文: <Distributed Hessian-Free Optimization for Deep Neural Network> 地址: https://arxiv.org/abs/ ...
- (续)在深度计算框架MindSpore中如何对不持续的计算进行处理——对数据集进行一定epoch数量的训练后,进行其他工作处理,再返回来接着进行一定epoch数量的训练——单步计算
内容接前文: https://www.cnblogs.com/devilmaycry812839668/p/14988686.html 这里我们考虑的数据集是自建数据集,那么效果又会如何呢??? im ...
- 构建无服务器数仓(三 )EMR Serverless 操作要点、优化以及开放集成测试
引言 在数据驱动的世界中,企业正在寻求可靠且高性能的解决方案来管理其不断增长的数据需求.本系列博客从一个重视数据安全和合规性的 B2C 金融科技客户的角度来讨论云上云下混合部署的情况下如何利用亚马逊云 ...
- 推荐一款Python开源移动应用安全测试分析工具!!!
今天给大家推荐一个安全测试相关的开源项目:nccgroup/house 1.介绍 它是一个由 NCC Group 开发的,一个基于Frida和Python编写的动态运行时移动应用分析工具包,提供了基于 ...
- quartz监控日志(二)添加监听器
上一章介绍监控job有三种方案,其实还有一个简单方案是实现quartz的TriggerListener. 上次我也试了这个方案,但是由于操作错误,导致没有监控成功,所以才选择分析源码来实现代理进行监控 ...
- 22张图详解浏览器请求数据包如何到达web服务器(搞懂网络可以毕业了)
浏览器的请求数据包如何到达web服务器? 很多读者对于其中的完整流程不是特别的了解,下面一口君通过这22张图,详细的讲解我们点击浏览器的网址之后,数据包是如何经过重重险阻到达web server的. ...