import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image # 配置GPU或CPU设置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 创建目录
# Create a directory if not exists
sample_dir = 'samples'
if not os.path.exists(sample_dir):
os.makedirs(sample_dir) # 超参数设置
# Hyper-parameters
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3 # 获取数据集
# MNIST dataset
dataset = torchvision.datasets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True) # 数据加载,按照batch_size大小加载,并随机打乱
data_loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True) # 定义VAE类
# VAE model
class VAE(nn.Module):
def __init__(self, image_size=784, h_dim=400, z_dim=20):
super(VAE, self).__init__()
self.fc1 = nn.Linear(image_size, h_dim)
self.fc2 = nn.Linear(h_dim, z_dim)
self.fc3 = nn.Linear(h_dim, z_dim)
self.fc4 = nn.Linear(z_dim, h_dim)
self.fc5 = nn.Linear(h_dim, image_size) # 编码 学习高斯分布均值与方差
def encode(self, x):
h = F.relu(self.fc1(x))
return self.fc2(h), self.fc3(h) # 将高斯分布均值与方差参数重表示,生成隐变量z 若x~N(mu, var*var)分布,则(x-mu)/var=z~N(0, 1)分布
def reparameterize(self, mu, log_var):
std = torch.exp(log_var / 2)
eps = torch.randn_like(std)
return mu + eps * std
# 解码隐变量z
def decode(self, z):
h = F.relu(self.fc4(z))
return F.sigmoid(self.fc5(h)) # 计算重构值和隐变量z的分布参数
def forward(self, x):
mu, log_var = self.encode(x)# 从原始样本x中学习隐变量z的分布,即学习服从高斯分布均值与方差
z = self.reparameterize(mu, log_var)# 将高斯分布均值与方差参数重表示,生成隐变量z
x_reconst = self.decode(z)# 解码隐变量z,生成重构x’
return x_reconst, mu, log_var# 返回重构值和隐变量的分布参数 # 构造VAE实例对象
model = VAE().to(device)
print(model)
# VAE( (fc1): Linear(in_features=784, out_features=400, bias=True)
# (fc2): Linear(in_features=400, out_features=20, bias=True)
# (fc3): Linear(in_features=400, out_features=20, bias=True)
# (fc4): Linear(in_features=20, out_features=400, bias=True)
# (fc5): Linear(in_features=400, out_features=784, bias=True)) # 选择优化器,并传入VAE模型参数和学习率
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#开始训练
for epoch in range(num_epochs):
for i, (x, _) in enumerate(data_loader):
# 前向传播
x = x.to(device).view(-1, image_size)# 将batch_size*1*28*28 ---->batch_size*image_size 其中,image_size=1*28*28=784
x_reconst, mu, log_var = model(x)# 将batch_size*748的x输入模型进行前向传播计算,重构值和服从高斯分布的隐变量z的分布参数(均值和方差) # 计算重构损失和KL散度
# Compute reconstruction loss and kl divergence
# For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
# 重构损失
reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
# KL散度
kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # 反向传播与优化
# 计算误差(重构误差和KL散度值)
loss = reconst_loss + kl_div
# 清空上一步的残余更新参数值
optimizer.zero_grad()
# 误差反向传播, 计算参数更新值
loss.backward()
# 将参数更新值施加到VAE model的parameters上
optimizer.step()
# 每迭代一定步骤,打印结果值
if (i + 1) % 10 == 0:
print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
.format(epoch + 1, num_epochs, i + 1, len(data_loader), reconst_loss.item(), kl_div.item())) with torch.no_grad():
# Save the sampled images
# 保存采样值
# 生成随机数 z
z = torch.randn(batch_size, z_dim).to(device)# z的大小为batch_size * z_dim = 128*20
# 对随机数 z 进行解码decode输出
out = model.decode(z).view(-1, 1, 28, 28)
# 保存结果值
save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1))) # Save the reconstructed images
# 保存重构值
# 将batch_size*748的x输入模型进行前向传播计算,获取重构值out
out, _, _ = model(x)
# 将输入与输出拼接在一起输出保存 batch_size*1*28*(28+28)=batch_size*1*28*56
x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))

大概长这么个样子:

附上一张结果图:

Variational Auto-encoder(VAE)变分自编码器-Pytorch的更多相关文章

  1. VAE变分自编码器

    我在学习VAE的时候遇到了很多问题,很多博客写的不太好理解,因此将很多内容重新进行了整合. 我自己的学习路线是先学EM算法再看的变分推断,最后学VAE,自我感觉这个线路比较好理解. 一.首先我们来宏观 ...

  2. VAE变分自编码器实现

    变分自编码器(VAE)组合了神经网络和贝叶斯推理这两种最好的方法,是最酷的神经网络,已经成为无监督学习的流行方法之一. 变分自编码器是一个扭曲的自编码器.同自编码器的传统编码器和解码器网络一起,具有附 ...

  3. 变分自编码器(Variational auto-encoder,VAE)

    参考: https://www.cnblogs.com/huangshiyu13/p/6209016.html https://zhuanlan.zhihu.com/p/25401928 https: ...

  4. (转) 变分自编码器(Variational Autoencoder, VAE)通俗教程

    变分自编码器(Variational Autoencoder, VAE)通俗教程 转载自: http://www.dengfanxin.cn/?p=334&sukey=72885186ae5c ...

  5. 变分自编码器(Variational Autoencoder, VAE)通俗教程

    原文地址:http://www.dengfanxin.cn/?p=334 1. 神秘变量与数据集 现在有一个数据集DX(dataset, 也可以叫datapoints),每个数据也称为数据点.我们假定 ...

  6. 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)

    变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...

  7. 基于变分自编码器(VAE)利用重建概率的异常检测

    本文为博主翻译自:Jinwon的Variational Autoencoder based Anomaly Detection using Reconstruction Probability,如侵立 ...

  8. 变分推断到变分自编码器(VAE)

    EM算法 EM算法是含隐变量图模型的常用参数估计方法,通过迭代的方法来最大化边际似然. 带隐变量的贝叶斯网络 给定N 个训练样本D={x(n)},其对数似然函数为: 通过最大化整个训练集的对数边际似然 ...

  9. 基于图嵌入的高斯混合变分自编码器的深度聚类(Deep Clustering by Gaussian Mixture Variational Autoencoders with Graph Embedding, DGG)

    基于图嵌入的高斯混合变分自编码器的深度聚类 Deep Clustering by Gaussian Mixture Variational Autoencoders with Graph Embedd ...

随机推荐

  1. Redis企业实战的几个坑

    一.前言 小伙伴们对Redis应该不陌生,Redis是系统必备的分布式缓存中间件,主要用来解决高并发下分担DB资源的负载,从而提升系统吞吐量. Redis支持多种数据类型,String(字符串).li ...

  2. Java 基础:单例模式 Singleton Pattern

    1.简介 单例模式(Singleton Pattern)是 Java 中最简单的设计模式之一.这种类型的设计模式属于创建型模式,它提供了一种创建对象的最佳方式. 这种模式涉及到一个单一的类,该类负责创 ...

  3. gstreamer的gst-inspect 和gst-launch

    用gstreamer架构做对媒体开发时,gst-inspect 和gst-launch是两个非常使用的小工具,前者是用于查询库中已经包含的所有element以及他们的详细信息,后者用于快速构建一条pi ...

  4. 实时查看linux下的日志

    cat /var/log/*.log 如果日志在更新,如何实时查看 tail -f /var/log/messages 还可以使用 watch -d -n 1 cat /var/log/message ...

  5. C#利用反射获取实体类的主键名称或者获取实体类的值

    //获取主键的 PropertyInfo PropertyInfo pkProp = ).FirstOrDefault(); //主键名称 var keyName=pkProp.Name; //实体类 ...

  6. PHP7 serialize_precision 配置不当导致 json_encode() 浮点小数溢出错误

    https://blog.csdn.net/moliyiran/article/details/81179825 感谢 @地狱星星:原因已找到, 该现象只出现在PHP 7.1+版本上建议使用默认值 s ...

  7. typeScript中的数据类型

    /* typeScript中的数据类型 typescript中为了使编写的代码更规范,更有利于维护,增加了类型校验,在typescript中主要给我们提供了以下数据类型 布尔类型(boolean) 数 ...

  8. Objective-C轻量级泛型

    在Apple发布Xcode7的时候,不仅把Swift编程语言升级到了2.0版本,而且还对Objective-C做了许多提升,包括引入__nonnull/__nullable.其中,对于Objectiv ...

  9. QQ第三方登录回调地址的问题

    如题,维护以前的项目,发现原来QQ的第三方登录竟然失败了.回调地址的问题 原来是以前的规则变了.好吧,那就改,谁叫我不是改变规则的人. 中途浪费了点时间,项目很大,我一下也找不到项目里那个接口调用的, ...

  10. ES6深入浅出-11 ES6新增的API(上)-2.Array新增API

    Array.form 把不是数组的东西变成数组.最常见的就是把伪数组变成数组 那么什么是伪数组 这就是伪数组,因为它不是继承自Array的原型的对象.它只是一个看起来很像数组的数组 只看下面的代码.a ...