Pytorch入门之VAE
关于自编码器的原理见另一篇博客 : 编码器AE & VAE
这里谈谈对于变分自编码器(Variational auto-encoder)即VAE的实现。
1. 稀疏编码
首先介绍一下“稀疏编码”这一概念。
早期学者在黑白风景照片中可以提取到许多16*16像素的图像碎片。而这些图像碎片几乎都可由64种正交的边组合得到。而且组合出一张碎片所需的边的数目很少,即稀疏的。同时在音频中大多数声音也可由几种基本结构组合得到。这其实就是特征的稀疏表达。即使用少量的基本特征来组合更加高层抽象的特征。在神经网络中即体现出前一层是未加工的像素,而后一层就是对这些像素的非线性组合。
有监督情况下可以利用深层卷积网络来提取特征,而自编码器就是无监督情况下根据自身的高阶特征编码自己。自编码器是输入输出相同的神经网络。其特点是利用稀疏的高阶特征来重构自己。一般而言自编码器的中间隐层节点的数量要小于输入节点的数量,即实现降维过程。因为对于少于输入节点的隐藏层来说无法将输入的全部信息保留,只能优先选择部分重要的特征,而后利用这些特征来复原。此外我们可以给隐层的权重加上L2正则,正则项惩罚因子越大,接近于0的系数越多,从而特征更加稀疏!
关于自编码器我们可以加入一些限制使其实现不同的功能,例如去噪自编码(Denoising AutoEncoder)。输入是加了噪声的数据,而输出是原始数据,在学习过程中,只有学到更鲁棒、更频繁的特征模式才能将噪声略去,回复原始数据。如果自编码器的隐层只有一层,那么原理类似于主成分分析PCA。
HInton提出的DBN模型有多个隐含层,每个隐含层都是限制玻尔兹曼机RBM。DBN训练时需先对每两层间进行无监督的预训练,这一过程实为一个多层的自编码器,可以将每整个网络的权重初始化到一个理想的分布。最后通过反向传播算法调整模型权重,这个步骤会使用经过标注的信息来做监督性的分类训练。当年DBN给训练深度神经网络提供了可能性,它解决了网络过深带来的深度弥散。简言之:先用自编码器的方法进行无监督的预训练,提取特征并初始化权重,然后使用标注信息进行监督式的训练。
2. VAE工作流程
先看下图:
AE的工作其实是实现了 图片->向量->图片 这一过程。就是说给定一张图片编码后得到一个向量,然后将这一向量进行解码后就得到了原始的图片。这个解码后的图片和之前的原图一样吗?不完全一样。因为一般而言,如前所述是从低维隐层中恢复原图。但是AE另我们现在能训练任意多的图片,如果我们把这些图片的编码向量存在来,那以后就能通过这些编码向量来重构我们的图像,称之为标准自编码器。可这还不够,如果现在我随机拿出一个很离谱的向量直接另其解码,那解码出来的东西十有八九是无意义的东西。
所以我们希望AE编码出的code符合一种分布(eg:高斯混合模型),那么我们就可以从这个高斯分布任意采样出一个code,给这个code解码那么就会生成一张原图类似的图。而这个强迫分布就是VAE与AE的不同之处了。VAE的编码器输出包括两部分:m和σ。其中e是正态分布, c为编码结果。m、e、σ、c的形状一样,都为(batch_size,latent_code_num) 。这个latent_code_num就相当于高斯混合分布的高斯数量。每个高斯都有自己的均值、方差。所以共有latent_code_num个均值、方差。
接下来是VAE的损失函数:由两部分的和组成(bce_loss、kld_loss)。bce_loss即为binary_cross_entropy(二分类交叉熵)损失,即用于衡量原图与生成图片的像素误差。kld_loss即为KL-divergence(KL散度),用来衡量潜在变量的分布和单位高斯分布的差异。
3. Pytorch实现
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 10 20:48:03 2018 @author: lps
""" import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
import torchvision.datasets as dst
from torchvision.utils import save_image EPOCH = 15
BATCH_SIZE = 64
n = 2 # num_workers
LATENT_CODE_NUM = 32
log_interval = 10 transform=transforms.Compose([transforms.ToTensor()])
data_train = dst.MNIST('MNIST_data/', train=True, transform=transform, download=False)
data_test = dst.MNIST('MNIST_data/', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=data_train, num_workers=n,batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=data_test, num_workers=n,batch_size=BATCH_SIZE, shuffle=True) class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__() self.encoder = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(128, 128, kernel_size=3 ,stride=1, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
) self.fc11 = nn.Linear(128 * 7 * 7, LATENT_CODE_NUM)
self.fc12 = nn.Linear(128 * 7 * 7, LATENT_CODE_NUM)
self.fc2 = nn.Linear(LATENT_CODE_NUM, 128 * 7 * 7) self.decoder = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
nn.Sigmoid()
) def reparameterize(self, mu, logvar):
eps = Variable(torch.randn(mu.size(0), mu.size(1))).cuda()
z = mu + eps * torch.exp(logvar/2) return z def forward(self, x):
out1, out2 = self.encoder(x), self.encoder(x) # batch_s, 8, 7, 7
mu = self.fc11(out1.view(out1.size(0),-1)) # batch_s, latent
logvar = self.fc12(out2.view(out2.size(0),-1)) # batch_s, latent
z = self.reparameterize(mu, logvar) # batch_s, latent
out3 = self.fc2(z).view(z.size(0), 128, 7, 7) # batch_s, 8, 7, 7 return self.decoder(out3), mu, logvar def loss_func(recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return BCE+KLD vae = VAE().cuda()
optimizer = optim.Adam(vae.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) def train(EPOCH):
vae.train()
total_loss = 0
for i, (data, _) in enumerate(train_loader, 0):
data = Variable(data).cuda()
optimizer.zero_grad()
recon_x, mu, logvar = vae.forward(data)
loss = loss_func(recon_x, data, mu, logvar)
loss.backward()
total_loss += loss.data[0]
optimizer.step() if i % log_interval == 0:
sample = Variable(torch.randn(64, LATENT_CODE_NUM)).cuda()
sample = vae.decoder(vae.fc2(sample).view(64, 128, 7, 7)).cpu()
save_image(sample.data.view(64, 1, 28, 28),
'result/sample_' + str(epoch) + '.png')
print('Train Epoch:{} -- [{}/{} ({:.0f}%)] -- Loss:{:.6f}'.format(
epoch, i*len(data), len(train_loader.dataset),
100.*i/len(train_loader), loss.data[0]/len(data))) print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, total_loss / len(train_loader.dataset))) for epoch in range(1, EPOCH):
train(epoch)
main.py
编解码器可由全连接或卷积网络实现。这里采用CNN。结果如下:
参考 :
《Tensoflow 实战》
Paper-Implementations
yunjey/pytorch-tutorial
Pytorch入门之VAE的更多相关文章
- [pytorch] Pytorch入门
Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...
- Pytorch入门随手记
Pytorch入门随手记 什么是Pytorch? Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的) 是一个动态的过程,数据和图是一起建立的. tensor.dot ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
随机推荐
- [BZOJ5248] 2018九省联考 D1T1 一双木棋 | 博弈论 状压DP
题面 菲菲和牛牛在一块\(n\)行\(m\)列的棋盘上下棋,菲菲执黑棋先手,牛牛执白棋后手. 棋局开始时,棋盘上没有任何棋子,两人轮流在格子上落子,直到填满棋盘时结束. 落子的规则是:一个格子可以落子 ...
- 洛谷P1600 天天爱跑步
天天放毒... 首先介绍一个树上差分. 每次进入的时候记录贡献,跟出来的时候的差值就是子树贡献. 然后就可以做了. 发现考虑每个人的贡献有困难. 于是考虑每个观察员的答案. 把路径拆成两条,以lca分 ...
- 【洛谷P1463】反素数
题目大意:给定 \(N < 2e9\),求不超过 N 的最大反素数. 题解: 引理1:不超过 2e9 的数的质因子分解中,最多有 10 个不同的质因子,且各个质因子的指数和不超过30. 引理2: ...
- 收藏:SQL重复记录查询 .
来自:http://blog.csdn.net/chinmo/article/details/2184020 1.查找表中多余的重复记录,重复记录是根据单个字段(peopleId)来判断select ...
- std::lock_guard和std::unique_lock
std::unique_lock也可以提供自动加锁.解锁功能,比std::lock_guard更加灵活 https://www.cnblogs.com/xudong-bupt/p/9194394.ht ...
- apigateway-kong(七)配置说明
这一部分应该在最开始介绍,但是我觉得在对kong有一定了解后再回头看下配置,会理解的更深刻.接下来对这个配置文件里的参数做个详细的解释便于更好的使用或优化kong网关. 目录 一.配置加载 二.验证配 ...
- Python中if-else的多种写法
a, b= 1, 2 将a和b两个变量中的最大值赋值给c (1)常规写法 if a>b: c = a else: c = b (2)表达式 c = a if a>b e ...
- 设计模式---单一职责模式之装饰模式(Decorator)
前提:"单一职责"模式 在软件组件的设计中,如果责任划分的不清晰,使用继承,得到的结果往往是随着需求的变化,子类急剧膨胀,同时充斥着重复代码,这时候的关键是划清责任 典型模式(表现 ...
- Dubbo协议
参考dubbo官方文档http://dubbo.apache.org/zh-cn/docs/user/references/protocol/dubbo.html dubbo共支持如下几种通信协议: ...
- CodeForces - 348D Turtles(LGV)
https://vjudge.net/problem/CodeForces-348D 题意 给一个m*n有障碍的图,求从左上角到右下角两条不相交路径的方案数. 分析 用LGV算法.从(1,1)-(n, ...