Gibbs Sampling

Intro

Gibbs Sampling 方法是我最近在看概率图模型相关的论文的时候遇见的,采样方法大致为:迭代抽样,最开始从随机样本中抽样,然后将此样本作为条件项,按条件概率抽样,每次只从一个维度考虑,当所有维度均采样完,开始下一轮迭代。

Random Sampling

假设我们一直一个随机变量的概率密度函数,我们如何采样得到服从这个分布的样本呢?

学矩阵论的时候,老师教我们用反函数来生成任意概率分布的随机数,因此,我们也可以用反函数法来生成该分布的样本。即假设 $ \xi $ 是 $ [0,1] $ 区间上均匀分布的随机变量,则其反函数$ cdf^{-1}( \xi ) $ 服从该概率密度函数为 $ p(x) $ 的分布。

有一个问题就是,当 $ p(x) $ 复杂到其累积分布函数的反函数无法计算的时候,或者不知道 $ p(x) $ 的精确值的时候,如何采样呢?

这时候就要用到一些采样的策略,比如拒绝采样、重要性采样、Gibbs采样等等。下面就记一下各种采样策略。

Rejection Sampling

拒绝采样的原理是,已知一个提议分布q(往往是简单分布)和原始分布p,从提议分布中采样一个样本\(\hat{x}\),然后计算接受率\(a(\hat{x}) = \frac{p(\hat{x}}{kq(\hat{x})}\),然后从均匀分布中生成一个值z,如果z小于等于a,则接受样本,否则不接受样本,继续采样,知道采样到了足够的样本。

这个图应该可以说明,上面蓝色的线是提议分布,必须包含原始分布,然后在z0处计算接受率即可。

然而拒绝采样要求提议分布和原始分布比较接近,这样采样率才会比较高,否则这个采样方法就是低效的,所以往往实际中并不采用这种采样方法。同样的,重要性采样方法也是比较低效的方法。(略去)

MCMC

MCMC是马尔可夫蒙特卡罗方法,是一种针对高维变量的采样方法。

MCMC的核心思想是将采样过程看成一个马尔可夫链,认为第t+1次采样是依赖于第t次抽取样本\(x_t\)以及状态转移分布\(q(x|x_t)\)。根据马尔可夫性链的收敛特性,我们知道在转移足够多此之后最终的状态将会收敛到一个固定的状态,我们假定收敛时的分布为\(p(x)\),那么在状态平稳时进行抽样得到的样本就肯定服从与\(p(x)\)分布。

MCMC一般应用的方法有Metropolis-Hastings算法和Gibbs采样算法。为了快点引入Gibbs Sampling,前者略去。

Gibbs Sampling

假设有一随机向量\(x = (x_1,x_2,...,x_d)\),其中d表示他有d维,每一维是一随机变量,且并不是我们常见的相互独立前提。那么,如果我们已知这个随机向量的概率分布,我们如何从这个分布中进行采样呢?

显然想要从多元分布的联合概率分布中直接抽样是相当困难的,而Gibbs Sampling就是一种简单而且有效的采样方法。吉布斯采样的大致步骤如下:

从一个随机的初始化状态\(x^{(0)}=[x_1|x_2^{(0)},x_3^{(0)},\cdots,x_d^{(0)}]\)开始,对每个维度单独进行采样,其采样顺序大致如下:

\[x_1^{(1)} \thicksim p(x_1|x_2^{(0)},x_3^{(0)},\cdots,x_d^{(0)}) \\x_2^{(1)} \thicksim p(x_2|x_1^{(0)},x_3^{(0)},\cdots,x_d^{(0)}) \\\vdots \\x_d^{(1)} \thicksim p(x_d|x_1^{(0)},x_2^{(0)},\cdots,x_{d-1}^{(0)}) \\\vdots \\x_1^{(t)} \thicksim p(x_1|x_2^{(t-1)},x_3^{(t-1)},\cdots,x_d^{(t-1)}) \\\vdots\\x_{d}^{(t)} \thicksim p(x_d|x_1^{(t-1)},x_2^{(t-1)},\cdots,x_{d-1}^{(t-1)}) \\
\]

遵从上面的采样步骤,我们最终能够采样得到所需要的高维分布的样本。需要注意的是,迭代的最开始采样得到的样本并不是完全满足所需要的分布的样本,因为采样之初采样的分布是提议分布,一般是均匀分布,而Gibbs Sampling的过程更像是一个单步迭代的过程,这使我想起了EM算法,都是一样的,一步一步去迭代达到最终结果。

我在网上找到了一个能够描述这个过程的图片:

如上图所示,右图是我们需要的分布,左边是迭代的过程,最开始抽样的点0和1都是均匀分布抽样得到的,而越到后面,抽样的点都越满足我们右边的分布,所以这个过程可以说明Gibbs Sampling抽样的过程是可行的。

还有下面这张图,也差不多:

Coding

Gibbs Sampling我是从一篇图像合成的论文中看到并有所了解的,文章基于MRF,使用神经网络去拟合条件分布\(p(x_i|x_{-i})\),其中\(x_{-i}\)表示除了第i个属性的其他属性。

具体到图像中来,\(x_i\)就是第i个位置的像素点的像素值,而\(x_{-i}\)描述的就是除了这个点以外的其他所有点,因此上式的概率分布就是一个条件分布。

使用神经网络可以拟合出这个分布来,那么如何去生成图片又是一个问题。

文章给出的解决方案就是Gibbs Sampling,先从随机噪声开始,逐像素进行生成,第一次迭代完成将生成一张图片,那么第二次第三次依次可以使用上一次迭代完前生成的图片进行迭代生成下一次,当迭代次数足够多的时候,即我们认为达到了平稳分布,这个时候生成的图片就是服从该分布的图片了。

原文参见:

原文链接

具体的,我给出下面的代码:

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils import data
from torchvision import datasets, transforms, utils
from tqdm import tqdm
from PIL import Image
import glob
import random
import cv2 as cv
class MConv(nn.Conv2d):
'''
mask_type A or B
A : the center is zero
B : the center is not zero
'''
def __init__(self,mask_type,*args,**kwargs):
super(MConv,self).__init__(*args,**kwargs)
assert mask_type in ["A","B"]
self.mask_type = mask_type
self.register_buffer('mask', self.weight.data.clone())
_,_,h,w = self.weight.size()
self.mask.fill_(1)
self.mask[:,:,h//2,w//2 + (mask_type == 'B'):] = 0
self.mask[:,:,h//2+1:,:] = 0 def forward(self,x):
self.weight.data *= self.mask
return super(MaskedConv2d,self).forward(x) class DoublePixelCNN(nn.Module):
def __init__(self,fm,kernel_size = 7,padding = 3):
super(DoublePixelCNN, self).__init__()
self.net1 = nn.Sequential(
MConv('A', 1, 64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
#nn.Conv2d(fm, 256, 1)
)
self.net2 = nn.Sequential(
MConv('A', 1, 64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
#nn.Conv2d(fm, 256, 1)
) self.conv1x1 = nn.Conv2d(fm*2, 256, 1)
def forward(self,x):
x1 = self.net1(x)
x2 = self.net2(x.flip(dims = [-1,-2]))
x = torch.cat([x1,x2.flip(dims = [-1,-2])],dim = 1)
x = self.conv1x1(x)
return x if __name__ == "__main__":
tr = data.DataLoader(datasets.MNIST(root="/media/xueaoru/Ubuntu/dataset/data",transform=transforms.ToTensor(),),
batch_size=64, shuffle=True, num_workers=12, pin_memory=True)
net = DoublePixelCNN(128)
net.cuda()
sample = torch.rand(64,1,k,k).cuda()
optimizer = optim.Adam(net.parameters(),lr = 0.0001)
for epoch in range(1000):
net.train()
running_loss = 0.
for input,_ in tqdm(tr):
#print(input.size())
input = input.cuda()
#target = target.cuda()
target = (input.data[:,:] * 255).long() # (b,3,h,w)
# net(input) (b,256,3,h,w)
loss = F.cross_entropy(net(input), target) # 计算的是每个像素的二分类交叉熵
running_loss += loss.item() optimizer.zero_grad()
loss.backward()
optimizer.step()
print("training loss: {:.8f}".format(running_loss / len(tr)))
if epoch % 5 == 0:
torch.save(net.state_dict(),open("./{}.pth".format(epoch),"wb"))
#sample.fill_(0)
net.eval()
with torch.no_grad():
for t in tqdm(range(300)):
for i in range(k):
for j in range(k):
out = net(sample) # (b,256)
probs = F.softmax(out[:, :, i ,j],dim = 1).data # (b,c) = (16,256)
sample[:, :, i, j] = torch.multinomial(probs, 1).float() / 255. utils.save_image(sample, 'sample_{:02d}.png'.format(epoch), nrow=12, padding=0)
sample = torch.rand(64,1,k,k).cuda()

由于这个方法采样时间极其缓慢,所以我生成的图片尺度比较小,训练周期也比较短,只是做个demo使用。

[学习笔记] Gibbs Sampling的更多相关文章

  1. PRML读书会第十一章 Sampling Methods(MCMC, Markov Chain Monte Carlo,细致平稳条件,Metropolis-Hastings,Gibbs Sampling,Slice Sampling,Hamiltonian MCMC)

    主讲人 网络上的尼采 (新浪微博: @Nietzsche_复杂网络机器学习) 网络上的尼采(813394698) 9:05:00  今天的主要内容:Markov Chain Monte Carlo,M ...

  2. 随机采样方法整理与讲解(MCMC、Gibbs Sampling等)

    本文是对参考资料中多篇关于sampling的内容进行总结+搬运,方便以后自己翻阅.其实参考资料中的资料写的比我好,大家可以看一下!好东西多分享!PRML的第11章也是sampling,有时间后面写到P ...

  3. Deep Learning(深度学习)学习笔记整理系列之(六)

    Deep Learning(深度学习)学习笔记整理系列 zouxy09@qq.com http://blog.csdn.net/zouxy09 作者:Zouxy version 1.0 2013-04 ...

  4. NLP︱高级词向量表达(二)——FastText(简述、学习笔记)

    FastText是Facebook开发的一款快速文本分类器,提供简单而高效的文本分类和表征学习的方法,不过这个项目其实是有两部分组成的,一部分是这篇文章介绍的 fastText 文本分类(paper: ...

  5. 随机采样和随机模拟:吉布斯采样Gibbs Sampling

    http://blog.csdn.net/pipisorry/article/details/51373090 吉布斯采样算法详解 为什么要用吉布斯采样 通俗解释一下什么是sampling. samp ...

  6. Deep learning with Python 学习笔记(10)

    生成式深度学习 机器学习模型能够对图像.音乐和故事的统计潜在空间(latent space)进行学习,然后从这个空间中采样(sample),创造出与模型在训练数据中所见到的艺术作品具有相似特征的新作品 ...

  7. 转 :hlda文献学习笔记

    David M.BLEI nCR文献学习笔记(基本完成了)  http://yhbys.blog.sohu.com/238343705.html 题目:The Nested Chinese Resta ...

  8. 深度学习方法:受限玻尔兹曼机RBM(三)模型求解,Gibbs sampling

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 接下来重点讲一下RBM模型求解 ...

  9. 机器学习方法(八):随机采样方法整理(MCMC、Gibbs Sampling等)

    转载请注明出处:Bin的专栏,http://blog.csdn.net/xbinworld 本文是对参考资料中多篇关于sampling的内容进行总结+搬运,方便以后自己翻阅.其实参考资料中的资料写的比 ...

随机推荐

  1. 11.jQuery之淡入淡出效果

    知识点:fadeIn   fadeOut  fadeToggle  fadeTo <style> div { width: 150px; height: 300px; background ...

  2. sql删除重复行和删除字段首位

    删除重复行 user_info: -- 单字段筛选重复行 SELECT *from user_info WHERE NAME in ( SELECT NAME from user_info GROUP ...

  3. SSH自动登录config文件配置

    title: SSH自动登录config文件配置 comments: false date: 2019-08-19 19:29:13 description: 更方便的 ssh 操作??? categ ...

  4. 玩转Android状态栏

    前言 前段时间,突然收到一个状态栏颜色优化设计的任务,将原本应用整体的黑色状态栏修改为根据标题栏颜色进行沉浸式设计,显示效果如下:   image 经过分析及踩过N多坑,终于完成了APP全局的修改.现 ...

  5. 配置ShiroFilter需要注意的问题(Shiro_DelegatingFilterProxy)

    ShiroFilter的工作原理 ShiroFilter:DelegatingFilterProxy作用是自动到Spring 容器查找名字为shiroFilter(filter-name)的bean并 ...

  6. 神奇的AI:将静态图片转为3D动图

    近日我们从外媒获得消息,位于莫斯科的三星AI中心和Skolkovo科学技术研究所的研究人员发表了一篇新论文,详细介绍了从单个静止人像照片生成3D动画人像的创建.与此前能够生成照片般逼真肖像的人工智能A ...

  7. PAT Advanced 1050 String Subtraction (20 分)

    Given two strings S​1​​ and S​2​​, S=S​1​​−S​2​​ is defined to be the remaining string after taking ...

  8. 2018牛客网暑期ACM多校训练营(第十场)J Rikka with Nickname(二分,字符串)

    链接:https://ac.nowcoder.com/acm/contest/148/J?&headNav=acm 来源:牛客网 Rikka with Nickname 时间限制:C/C++ ...

  9. ui自动化之selenium操作(四)简单元素操作

    1. clear() clear()方法用于清除文本输入框内的内容:一般输入框中都有默认文字,如果不清空有可能会导致字符拼接: browser.find_element(By.ID,"use ...

  10. Hadoop 开发环境虚拟机搭建

    软件下载: VMware软件: 链接:https://pan.baidu.com/s/1gWinLJpfWdAQ8AyEkZxpfg 密码:i2ap 下载Ubuntu 镜像文件; 链接:https:/ ...