[学习笔记] Gibbs Sampling
Gibbs Sampling
Intro
Gibbs Sampling 方法是我最近在看概率图模型相关的论文的时候遇见的,采样方法大致为:迭代抽样,最开始从随机样本中抽样,然后将此样本作为条件项,按条件概率抽样,每次只从一个维度考虑,当所有维度均采样完,开始下一轮迭代。
Random Sampling
假设我们一直一个随机变量的概率密度函数,我们如何采样得到服从这个分布的样本呢?
学矩阵论的时候,老师教我们用反函数来生成任意概率分布的随机数,因此,我们也可以用反函数法来生成该分布的样本。即假设 $ \xi $ 是 $ [0,1] $ 区间上均匀分布的随机变量,则其反函数$ cdf^{-1}( \xi ) $ 服从该概率密度函数为 $ p(x) $ 的分布。
有一个问题就是,当 $ p(x) $ 复杂到其累积分布函数的反函数无法计算的时候,或者不知道 $ p(x) $ 的精确值的时候,如何采样呢?
这时候就要用到一些采样的策略,比如拒绝采样、重要性采样、Gibbs采样等等。下面就记一下各种采样策略。
Rejection Sampling
拒绝采样的原理是,已知一个提议分布q(往往是简单分布)和原始分布p,从提议分布中采样一个样本\(\hat{x}\),然后计算接受率\(a(\hat{x}) = \frac{p(\hat{x}}{kq(\hat{x})}\),然后从均匀分布中生成一个值z,如果z小于等于a,则接受样本,否则不接受样本,继续采样,知道采样到了足够的样本。
这个图应该可以说明,上面蓝色的线是提议分布,必须包含原始分布,然后在z0处计算接受率即可。
然而拒绝采样要求提议分布和原始分布比较接近,这样采样率才会比较高,否则这个采样方法就是低效的,所以往往实际中并不采用这种采样方法。同样的,重要性采样方法也是比较低效的方法。(略去)
MCMC
MCMC是马尔可夫蒙特卡罗方法,是一种针对高维变量的采样方法。
MCMC的核心思想是将采样过程看成一个马尔可夫链,认为第t+1次采样是依赖于第t次抽取样本\(x_t\)以及状态转移分布\(q(x|x_t)\)。根据马尔可夫性链的收敛特性,我们知道在转移足够多此之后最终的状态将会收敛到一个固定的状态,我们假定收敛时的分布为\(p(x)\),那么在状态平稳时进行抽样得到的样本就肯定服从与\(p(x)\)分布。
MCMC一般应用的方法有Metropolis-Hastings算法和Gibbs采样算法。为了快点引入Gibbs Sampling,前者略去。
Gibbs Sampling
假设有一随机向量\(x = (x_1,x_2,...,x_d)\),其中d表示他有d维,每一维是一随机变量,且并不是我们常见的相互独立前提。那么,如果我们已知这个随机向量的概率分布,我们如何从这个分布中进行采样呢?
显然想要从多元分布的联合概率分布中直接抽样是相当困难的,而Gibbs Sampling就是一种简单而且有效的采样方法。吉布斯采样的大致步骤如下:
从一个随机的初始化状态\(x^{(0)}=[x_1|x_2^{(0)},x_3^{(0)},\cdots,x_d^{(0)}]\)开始,对每个维度单独进行采样,其采样顺序大致如下:
\]
遵从上面的采样步骤,我们最终能够采样得到所需要的高维分布的样本。需要注意的是,迭代的最开始采样得到的样本并不是完全满足所需要的分布的样本,因为采样之初采样的分布是提议分布,一般是均匀分布,而Gibbs Sampling的过程更像是一个单步迭代的过程,这使我想起了EM算法,都是一样的,一步一步去迭代达到最终结果。
我在网上找到了一个能够描述这个过程的图片:
如上图所示,右图是我们需要的分布,左边是迭代的过程,最开始抽样的点0和1都是均匀分布抽样得到的,而越到后面,抽样的点都越满足我们右边的分布,所以这个过程可以说明Gibbs Sampling抽样的过程是可行的。
还有下面这张图,也差不多:
Coding
Gibbs Sampling我是从一篇图像合成的论文中看到并有所了解的,文章基于MRF,使用神经网络去拟合条件分布\(p(x_i|x_{-i})\),其中\(x_{-i}\)表示除了第i个属性的其他属性。
具体到图像中来,\(x_i\)就是第i个位置的像素点的像素值,而\(x_{-i}\)描述的就是除了这个点以外的其他所有点,因此上式的概率分布就是一个条件分布。
使用神经网络可以拟合出这个分布来,那么如何去生成图片又是一个问题。
文章给出的解决方案就是Gibbs Sampling,先从随机噪声开始,逐像素进行生成,第一次迭代完成将生成一张图片,那么第二次第三次依次可以使用上一次迭代完前生成的图片进行迭代生成下一次,当迭代次数足够多的时候,即我们认为达到了平稳分布,这个时候生成的图片就是服从该分布的图片了。
原文参见:
具体的,我给出下面的代码:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils import data
from torchvision import datasets, transforms, utils
from tqdm import tqdm
from PIL import Image
import glob
import random
import cv2 as cv
class MConv(nn.Conv2d):
'''
mask_type A or B
A : the center is zero
B : the center is not zero
'''
def __init__(self,mask_type,*args,**kwargs):
super(MConv,self).__init__(*args,**kwargs)
assert mask_type in ["A","B"]
self.mask_type = mask_type
self.register_buffer('mask', self.weight.data.clone())
_,_,h,w = self.weight.size()
self.mask.fill_(1)
self.mask[:,:,h//2,w//2 + (mask_type == 'B'):] = 0
self.mask[:,:,h//2+1:,:] = 0
def forward(self,x):
self.weight.data *= self.mask
return super(MaskedConv2d,self).forward(x)
class DoublePixelCNN(nn.Module):
def __init__(self,fm,kernel_size = 7,padding = 3):
super(DoublePixelCNN, self).__init__()
self.net1 = nn.Sequential(
MConv('A', 1, 64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
#nn.Conv2d(fm, 256, 1)
)
self.net2 = nn.Sequential(
MConv('A', 1, 64, 17, 1,8, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
MConv('B', 64, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
MConv('B', fm, fm, kernel_size, 1, padding, bias=False), nn.BatchNorm2d(fm), nn.ReLU(True),
#nn.Conv2d(fm, 256, 1)
)
self.conv1x1 = nn.Conv2d(fm*2, 256, 1)
def forward(self,x):
x1 = self.net1(x)
x2 = self.net2(x.flip(dims = [-1,-2]))
x = torch.cat([x1,x2.flip(dims = [-1,-2])],dim = 1)
x = self.conv1x1(x)
return x
if __name__ == "__main__":
tr = data.DataLoader(datasets.MNIST(root="/media/xueaoru/Ubuntu/dataset/data",transform=transforms.ToTensor(),),
batch_size=64, shuffle=True, num_workers=12, pin_memory=True)
net = DoublePixelCNN(128)
net.cuda()
sample = torch.rand(64,1,k,k).cuda()
optimizer = optim.Adam(net.parameters(),lr = 0.0001)
for epoch in range(1000):
net.train()
running_loss = 0.
for input,_ in tqdm(tr):
#print(input.size())
input = input.cuda()
#target = target.cuda()
target = (input.data[:,:] * 255).long() # (b,3,h,w)
# net(input) (b,256,3,h,w)
loss = F.cross_entropy(net(input), target) # 计算的是每个像素的二分类交叉熵
running_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("training loss: {:.8f}".format(running_loss / len(tr)))
if epoch % 5 == 0:
torch.save(net.state_dict(),open("./{}.pth".format(epoch),"wb"))
#sample.fill_(0)
net.eval()
with torch.no_grad():
for t in tqdm(range(300)):
for i in range(k):
for j in range(k):
out = net(sample) # (b,256)
probs = F.softmax(out[:, :, i ,j],dim = 1).data # (b,c) = (16,256)
sample[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.
utils.save_image(sample, 'sample_{:02d}.png'.format(epoch), nrow=12, padding=0)
sample = torch.rand(64,1,k,k).cuda()
由于这个方法采样时间极其缓慢,所以我生成的图片尺度比较小,训练周期也比较短,只是做个demo使用。
[学习笔记] Gibbs Sampling的更多相关文章
- PRML读书会第十一章 Sampling Methods(MCMC, Markov Chain Monte Carlo,细致平稳条件,Metropolis-Hastings,Gibbs Sampling,Slice Sampling,Hamiltonian MCMC)
主讲人 网络上的尼采 (新浪微博: @Nietzsche_复杂网络机器学习) 网络上的尼采(813394698) 9:05:00 今天的主要内容:Markov Chain Monte Carlo,M ...
- 随机采样方法整理与讲解(MCMC、Gibbs Sampling等)
本文是对参考资料中多篇关于sampling的内容进行总结+搬运,方便以后自己翻阅.其实参考资料中的资料写的比我好,大家可以看一下!好东西多分享!PRML的第11章也是sampling,有时间后面写到P ...
- Deep Learning(深度学习)学习笔记整理系列之(六)
Deep Learning(深度学习)学习笔记整理系列 zouxy09@qq.com http://blog.csdn.net/zouxy09 作者:Zouxy version 1.0 2013-04 ...
- NLP︱高级词向量表达(二)——FastText(简述、学习笔记)
FastText是Facebook开发的一款快速文本分类器,提供简单而高效的文本分类和表征学习的方法,不过这个项目其实是有两部分组成的,一部分是这篇文章介绍的 fastText 文本分类(paper: ...
- 随机采样和随机模拟:吉布斯采样Gibbs Sampling
http://blog.csdn.net/pipisorry/article/details/51373090 吉布斯采样算法详解 为什么要用吉布斯采样 通俗解释一下什么是sampling. samp ...
- Deep learning with Python 学习笔记(10)
生成式深度学习 机器学习模型能够对图像.音乐和故事的统计潜在空间(latent space)进行学习,然后从这个空间中采样(sample),创造出与模型在训练数据中所见到的艺术作品具有相似特征的新作品 ...
- 转 :hlda文献学习笔记
David M.BLEI nCR文献学习笔记(基本完成了) http://yhbys.blog.sohu.com/238343705.html 题目:The Nested Chinese Resta ...
- 深度学习方法:受限玻尔兹曼机RBM(三)模型求解,Gibbs sampling
欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 接下来重点讲一下RBM模型求解 ...
- 机器学习方法(八):随机采样方法整理(MCMC、Gibbs Sampling等)
转载请注明出处:Bin的专栏,http://blog.csdn.net/xbinworld 本文是对参考资料中多篇关于sampling的内容进行总结+搬运,方便以后自己翻阅.其实参考资料中的资料写的比 ...
随机推荐
- Vue 系列(一): Vue + Echarts 开发可复用的柱形图组件
目录 前置条件 安装echarts 引入echarts 柱形图组件开发 在何时初始化组件? 完整的代码 记得注册组件!!! 本文归柯三(kesan)所有,转载请注明出处 https://www.cnb ...
- qt嵌入式html和本地c++通信方式
前沿:我们在做qt项目的时候,通常会把某个html网页直接显示到应用程序中.比如绘图.直接把html形式的图标嵌入到应用程序中 但是我们需要把数据从后台c++端传到html端,实现显示.qt实现了相关 ...
- lamp项目上线流程简述 (ubuntu16.04 )
1 新建一个sudo用户,而不是直接用root操作 ① 新建用户可参考 https://www.cnblogs.com/bushuwei/p/10880182.html ② 赋予sudo权限: ...
- linux复习3:linux字符界面的操作
一.前言 1.对linux服务器进行管理的时候,经常要进入字符界面进行操作,使用命令需要记住该命令的相关选项和参数.vi编辑器可以用于编辑任何ASCII文本,功能非常的强大,可以对文本进行创建.查找. ...
- 深入理解java虚拟机(3)垃圾收集器与内存分配策略
一.根搜索算法: (1)定义:通过一系列名为"GC Roots"的对象作为起点,从这些起点开始向下搜索,搜索走过的路径称为引用链,当一个对象到GC Roots没有任何引用链相连的时 ...
- mac osx sed 命令
$ sed -i "s/devicedemo/device/g" `grep devicedemo -rl ./` sed: 1: ".//.coveragerc&quo ...
- RubyGems 库发现了后门版本的网站开发工具 bootstrap-sass
安全研究人员在官方的 RubyGems 库发现了后门版本的网站开发工具 bootstrap-sass.该工具的下载量高达 2800 万次,但这并不意味着下载的所有版本都存在后门,受影响的版本是 v3. ...
- hive模拟数据
人员表 id,姓名,爱好,住址 1,小明1,lol-book-movie,beijing:mashibing-shanghai:pudong 2,小明2,lol-book-movie,beijing: ...
- PCRE does not support \L, \l, \N{name}, \U, or \u...
PCRE does not support \L, \l, \N{name}, \U, or \u... 参考文章:YCSUNNYLIFE 的 <php 正则匹配中文> 一.报错情景: 使 ...
- JS 验证码的实现
转自:https://github.com/ace0109/verifyCode 正要做一个验证码,网上找到这个还不错: gVerify.js: !(function(window, document ...