intractable棘手的,难处理的  posterior distributions后验分布 directed probabilistic有向概率

approximate inference近似推理  multivariate Gaussian多元高斯  diagonal对角 maximum likelihood极大似然

参考:https://blog.csdn.net/yao52119471/article/details/84893634

VAE论文所在讲的问题是:

我们现在就是想要训练一个模型P(x),并求出其参数Θ:

通过极大似然估计求其参数

Variational Inference

在论文中P(x)模型会被拆分成两部分,一部分由数据x生成潜在向量z,即pθ(z|X);一部分从z重新在重构数据x,即pθ(X|z)

实现过程则是希望能够使用一个qΦ(z|X)模型去近似pθ(z|X),然后作为模型的Encoder;后半部分pθ(X|z)则作为Decoder,Φ/θ表示参数,实现一种同时学习识别模型参数φ和参数θ的生成模型的方法,推导过程为:

现在问题就在于怎么进行求导,因为现在模型已经不是一个完整的P(x) = pθ(z|X) + pθ(X|z),现在变成了P(x) = qΦ(z|X) + pθ(X|z),那么如果对Φ求导就会变成一个问题,因此论文中就提出了一个reparameterization trick方法:

取样于一个标准正态分布来采样z,以此将qΦ(z|X) 和pθ(X|z)两个子模型通过z连接在了一起

最终的目标函数为:

因此目标函数 = 输入和输出x求MSELoss - KL(qΦ(z|X) || pθ(z))

在论文上对式子最后的KL散度 -KL(qΦ(z|X) || pθ(z))的计算有简化为:

多维KL散度的推导可见:KL散度

假设pθ(z)服从标准正态分布,采样ε服从标准正态分布满足该假设

简单代码实现:

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt class Encoder(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super(Encoder, self).__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out) def forward(self, x):
x = F.relu(self.linear1(x))
return F.relu(self.linear2(x)) class Decoder(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super(Decoder, self).__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out) def forward(self, x):
x = F.relu(self.linear1(x))
return F.relu(self.linear2(x)) class VAE(torch.nn.Module):
latent_dim = def __init__(self, encoder, decoder):
super(VAE, self).__init__()
self.encoder = encoder
self.decoder = decoder
self._enc_mu = torch.nn.Linear(, )
self._enc_log_sigma = torch.nn.Linear(, ) def _sample_latent(self, h_enc):
"""
Return the latent normal sample z ~ N(mu, sigma^)
"""
mu = self._enc_mu(h_enc)
log_sigma = self._enc_log_sigma(h_enc) #得到的值是loge(sigma)
sigma = torch.exp(log_sigma) # = e^loge(sigma) = sigma
#从均匀分布中取样
std_z = torch.from_numpy(np.random.normal(, , size=sigma.size())).float() self.z_mean = mu
self.z_sigma = sigma return mu + sigma * Variable(std_z, requires_grad=False) # Reparameterization trick def forward(self, state):
h_enc = self.encoder(state)
z = self._sample_latent(h_enc)
return self.decoder(z) # 计算KL散度的公式
def latent_loss(z_mean, z_stddev):
mean_sq = z_mean * z_mean
stddev_sq = z_stddev * z_stddev
return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - ) if __name__ == '__main__': input_dim = *
batch_size = transform = transforms.Compose(
[transforms.ToTensor()])
mnist = torchvision.datasets.MNIST('./', download=True, transform=transform) dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,
shuffle=True, num_workers=) print('Number of samples: ', len(mnist)) encoder = Encoder(input_dim, , )
decoder = Decoder(, , input_dim)
vae = VAE(encoder, decoder) criterion = nn.MSELoss() optimizer = optim.Adam(vae.parameters(), lr=0.0001)
l = None
for epoch in range():
for i, data in enumerate(dataloader, ):
inputs, classes = data
inputs, classes = Variable(inputs.resize_(batch_size, input_dim)), Variable(classes)
optimizer.zero_grad()
dec = vae(inputs)
ll = latent_loss(vae.z_mean, vae.z_sigma)
loss = criterion(dec, inputs) + ll
loss.backward()
optimizer.step()
l = loss.data[]
print(epoch, l) plt.imshow(vae(inputs).data[].numpy().reshape(, ), cmap='gray')
plt.show(block=True)

VAE论文学习的更多相关文章

  1. Faster RCNN论文学习

    Faster R-CNN在Fast R-CNN的基础上的改进就是不再使用选择性搜索方法来提取框,效率慢,而是使用RPN网络来取代选择性搜索方法,不仅提高了速度,精确度也更高了 Faster R-CNN ...

  2. 《Explaining and harnessing adversarial examples》 论文学习报告

    <Explaining and harnessing adversarial examples> 论文学习报告 组员:裴建新   赖妍菱    周子玉 2020-03-27 1 背景 Sz ...

  3. 论文学习笔记 - 高光谱 和 LiDAR 融合分类合集

    A³CLNN: Spatial, Spectral and Multiscale Attention ConvLSTM Neural Network for Multisource Remote Se ...

  4. Apache Calcite 论文学习笔记

    特别声明:本文来源于掘金,"预留"发表的[Apache Calcite 论文学习笔记](https://juejin.im/post/5d2ed6a96fb9a07eea32a6f ...

  5. FactorVAE论文学习-1

    Disentangling by Factorising 我们定义和解决了从变量的独立因素生成的数据的解耦表征的无监督学习问题.我们提出了FactorVAE方法,通过鼓励表征的分布因素化且在维度上独立 ...

  6. GoogleNet:inceptionV3论文学习

    Rethinking the Inception Architecture for Computer Vision 论文地址:https://arxiv.org/abs/1512.00567 Abst ...

  7. IEEE Trans 2008 Gradient Pursuits论文学习

    之前所学习的论文中求解稀疏解的时候一般采用的都是最小二乘方法进行计算,为了降低计算复杂度和减少内存,这篇论文梯度追踪,属于贪婪算法中一种.主要为三种:梯度(gradient).共轭梯度(conjuga ...

  8. Raft论文学习笔记

    先附上论文链接  https://pdos.csail.mit.edu/6.824/papers/raft-extended.pdf 最近在自学MIT的6.824分布式课程,找到两个比较好的githu ...

  9. 论文学习-系统评估卷积神经网络各项超参数设计的影响-Systematic evaluation of CNN advances on the ImageNet

    博客:blog.shinelee.me | 博客园 | CSDN 写在前面 论文状态:Published in CVIU Volume 161 Issue C, August 2017 论文地址:ht ...

随机推荐

  1. python使用二分法实现在一个有序列表中查找指定的元素

    二分法是一种快速查找的方法,时间复杂度低,逻辑简单易懂,总的来说就是不断的除以2除以2... 例如需要查找有序list里面的某个关键字key的位置,那么首先确认list的中位数mid,下面分为三种情况 ...

  2. 深入理解flask 笔记

    ===sqlalchemy创建的数据模型中:1 字段是类属性   [模型中定义的字段是类属性,表单中定义的字段也是类字段] 2 若数据库不支持bool类型,则sqlalchemy会自动将bool转成0 ...

  3. 织梦DedeCMS会员空间内的文章列表无法分页的解决方法

    DedeCMS 5.7会员空间的文章列表分页显示不正常,总是显示0页0条记录错误.下面告诉大家如何解决这个问题: 找到并打开include/arc.memberlistview.class.php文件 ...

  4. http和https 握手过程

    这几天测试打印机一直出现打印延迟或者不打印的BUG.找了几天也没有发现为啥没有打印或者打印延迟.然后今天公司的研发大佬过来找问题,并开个会,瞬间所有的问题都找出了并且知道怎么解决了.大佬果然还是大佬. ...

  5. ML,DL核心数学及算法知识点总结

    ML,DL核心数学及算法知识点总结:https://mp.weixin.qq.com/s/bskyMQ2i1VMNiYKIvw_d7g

  6. HTML 010 radio

    Struts2单选按钮标签s:radio的使用及其设置默认值 转载atom168 发布于2014-12-01 15:40:59 阅读数 519  收藏 展开 首先在页面中引入struts标签库: &l ...

  7. git 查看项目代码统计命令

    git log --author="xxxxxxxx" --pretty=tformat: --numstat | gawk '{ add += $1 ; subs += $2 ; ...

  8. 洛谷P1706全排列问题

     P1706 全排列问题 题目描述 输出自然数1到n所有不重复的排列,即n的全排列,要求所产生的任一数字序列中不允许出现重复的数字. 输入输出格式 输入格式: n(1≤n≤9) 输出格式: 由1-n组 ...

  9. 29、Java虚拟机垃圾回收调优

    一.背景 如果在持久化RDD的时候,持久化了大量的数据,那么Java虚拟机的垃圾回收就可能成为一个性能瓶颈.因为Java虚拟机会定期进行垃圾回收,此时就会追踪所有的java对象, 并且在垃圾回收时,找 ...

  10. C语言strncasecmp()函数:比较字符串的前n个字符

    定义 int strncasecmp(const char *s1, const char *s2, size_t n); 描述 strncasecmp()用来比较参数s1 和s2 字符串前n个字符, ...