VAEs最早由“Diederik P. Kingma and Max Welling, “Auto-Encoding Variational Bayes, arXiv (2013)”和“Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra, “Stochastic Backpropagation and Approximate Inference in Deep Generative Models,” arXiv (2014)”同时发现。

原理:

对自编码器来说,它只是将输入数据投影到隐空间中,这些数据在隐空间中的位置是离散的,因此在此空间中进行采样,解码后的输出很可能是毫无意义的。

而对VAEs来说,它将输入数据转换成2个分布,一个是平均值的分布,一个是方差的分布(这就像高斯混合型了),添加上一些噪音,组合后,再进行解码。

如图(网上找的,应该是论文里的,暂时没看论文)

为什么分为2个分布?

可以这么理解:假设均值和方差都有n个,那么编码部分相当于用n个高斯分布(每个输入是不同权重的n个分布的组合)去模拟输入。

再通过一系列变换,转化为隐空间的若干维度,其每个维度可能具有某种意义。比如下面代码使用2维隐空间,可以看作是均值和方差维度。

方差部分指数化,保证非负。添加噪音让隐空间更具有意义的连续性。

然后我们从隐空间采样,由于隐空间具有意义上的连续性,那么解码后的东东就可能类似输入。

损失loss如何定义?为什么?

loss由2部分构成,第一部分就是解码输出与原始输入的loss,可以定义为交叉熵或者均方误差等。

第二部分是约束项。如上图黄色框,m平方作为L2正则化项,前2项可以看做方差减去其泰勒展开,当σ趋近0时,方差也即e^σ为1。那么最小化前2项必然使得σ趋近0(求导即可知)。

由此,这第二部分,m平方项约束使得均值为0,前2项约束使得方差为1。这样约束使得隐空间具有连续性,且强制输入数据在隐空间中的表示范围收拢。

这样在隐空间中2个数据表示的中间,就有一种过渡区域。如果仅以第一部分约束,效果可能就和自编码器一样了,模型会过拟合。


下面进入代码部分

以MNIST数据集作为训练样本。

from keras import backend as K

from keras.models import Model

from keras.metrics import binary_crossentropy

import numpy as np

from keras.layers import Conv2D,Flatten,Dense,Input,Lambda,Reshape,Conv2DTranspose,Layer

from keras.datasets import mnist

from keras.callbacks import EarlyStopping

编码器使用卷积层,输出2个部分

img_shape=(28,28,1)
batch_size=16
latent_dim=2 input_img=Input(shape=img_shape)
x=Conv2D(32,3,padding='same',activation='relu')(input_img)# 28,28,32
x=Conv2D(64,3,padding='same',activation='relu',strides=(2,2))(x)# 14,14,64
x=Conv2D(64,3,padding='same',activation='relu')(x)#14,14,64
x=Conv2D(64,3,padding='same',activation='relu')(x)#14,14,64
# 保存Flatten之前的shape
shape_before_flattening=K.int_shape(x)
x=Flatten()(x)#14*14*64
x=Dense(32,activation='relu')(x)#
# 将输入图像拆分为2个向量
z_mean=Dense(latent_dim)(x)#
z_log_var=Dense(latent_dim)(x)

定义采样方法

def sampling(args):
z_mean,z_log_var=args
# 得到一个平均值为0,方差为1的正态分布,shape为(?,2)
epsilon=K.random_normal(shape=(K.shape(z_mean)[0],latent_dim),mean=0,stddev=1.)#K.shape返回仍是tensor
# tensor*tensor为elementwise操作
return z_mean+K.exp(z_log_var)*epsilon
z=Lambda(sampling)([z_mean,z_log_var])# 采样

解码

# 解码过程,逆操作
decode_input=Input(K.int_shape(z)[1:])
# np.prod表示对数组某个axis进行乘法操作,如果axis不指定,则将所有的元素乘积返回一个值
x=Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decode_input)#14*14*64
# 逆Flatten操作
x=Reshape(shape_before_flattening[1:])(x)#14,14,64
# 反卷积,strides=2将14*14变为28*28,跟Conv2D相反
x=Conv2DTranspose(32,3,padding='same',activation='relu',strides=2)(x)#28,28,32
# 注意这里的激活函数
x=Conv2D(1,3,padding='same',activation='sigmoid')(x)#28,28,1
# 解码model
decoder=Model(decode_input,x)
# 解码后的图片数据
z_decoded=decoder(z)

定义loss,使用一个自定义layer实现

class CustomVariationalLayer(Layer):
def vae_loss(self,x,z_decoded):
x=K.flatten(x)
z_decoded=K.flatten(z_decoded)
# loss为原始输入和编码-解码后的输出比较
xent_loss=binary_crossentropy(x,z_decoded)
# 约束
# mean部分表示L2正则损失,K.exp(z_log_var)-(1+z_log_var)保证方差为1,如果不约束,网络可能偷懒
kl_loss=5e-4*K.mean(K.exp(z_log_var)-(1+z_log_var)+K.square(z_mean),axis=-1)
return K.mean(xent_loss+kl_loss) def call(self,inputs):
x=inputs[0]
z_decoded=inputs[1]
loss=self.vae_loss(x,z_decoded)
# 继承方法
self.add_loss(loss,inputs=inputs)#将根据inputs计算的损失loss加到本layer
return x #不用,但是需要返回点啥 y=CustomVariationalLayer()([input_img,z_decoded])

加载数据,定义、训练模型

(x_train,y_train),(x_test,y_test)=mnist.load_data()

x_train=x_train.astype('float32')/255.
# 表示添加一个通道维度,通道数为1(颜色只有一种模式)
x_train=x_train.reshape(x_train.shape+(1,))
x_test=x_test.astype('float32')/255.
x_test=x_test.reshape(x_test.shape+(1,))
vae=Model(input_img,y)
# 自定义层y里面已经包含了loss,这里不需要指定
vae.compile(optimizer='rmsprop',loss=None)
# 不需要标签,所以y为None,我们只需要知道一个图片的原始输入是否和编码-解码后的输出一致
vae.fit(x=x_train,y=None,shuffle=True,epochs=10,batch_size=batch_size,validation_data=(x_test,None),callbacks=[EarlyStopping(patience=2)],verbose=2)

测试

import matplotlib.pyplot as plt
from scipy.stats import norm # 潜空间中任意矢量可以解码成数字
n = 10
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# norm.ppf([v1,v2...])表示正态分布积分值为vi时,对应的x轴坐标值xi
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))#可以看作均值
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))#方差
for i, yi in enumerate(grid_x):
for j, xi in enumerate(grid_y):
z_sample = np.array([[xi, yi]])
# np.tile将数组重复n次,如[1,2]->[1,2,1,2]。然后reshape到输入格式
z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
x_decoded = decoder.predict(z_sample, batch_size=batch_size)
# 因为x_decoded为16个相同矢量得到的推导,取第一个就行,再将 28*28*1 reshape到 28*28
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i + 1) * digit_size,
j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

结果如下,可以看到,图片是连续变化的。

VAEs(变分自编码)之keras实践的更多相关文章

  1. Keras实践:模型可视化

    Keras实践:模型可视化 安装Graphviz 官方网址为:http://www.graphviz.org/.我使用的是mac系统,所以我分享一下我使用时遇到的坑. Mac安装时在终端中执行: br ...

  2. Keras实践:实现非线性回归

    Keras实践:实现非线性回归 代码 import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" import ke ...

  3. GAN(生成对抗网络)之keras实践

    GAN由论文<Ian Goodfellow et al., “Generative Adversarial Networks,” arXiv (2014)>提出. GAN与VAEs的区别 ...

  4. 分享几个 PHP 编码的最佳实践

    对于初学者而言,可能很难理解为什么某些做法更安全. 但是,以下一些技巧可能超出了 PHP 的范围. 始终使用大括号 让我们看下面的代码: if (isset($condition) && ...

  5. 2.keras实现-->字符级或单词级的one-hot编码 VS 词嵌入

    1. one-hot编码 # 字符集的one-hot编码 import string samples = ['zzh is a pig','he loves himself very much','p ...

  6. ​结合异步模型,再次总结Netty多线程编码最佳实践

    更多技术分享可关注我 前言 本文重点总结Netty多线程的一些编码最佳实践和注意事项,并且顺便对Netty的线程调度模型,和异步模型做了一个汇总.原文:​​结合异步模型,再次总结Netty多线程编码最 ...

  7. 文本离散表示(二):新闻语料的one-hot编码

    上一篇博客介绍了文本离散表示的one-hot.TF-IDF和n-gram方法,在这篇文章里,我做了一个对新闻文本进行one-hot编码的小实践. 文本的one-hot相对而言比较简单,我用了两种方法, ...

  8. 通过keras例子理解LSTM 循环神经网络(RNN)

    博文的翻译和实践: Understanding Stateful LSTM Recurrent Neural Networks in Python with Keras 正文 一个强大而流行的循环神经 ...

  9. 算术编码Arithmetic Coding-高质量代码实现详解

    关于算术编码的具体讲解我不多细说,本文按照下述三个部分构成. 两个例子分别说明怎么用算数编码进行编码以及解码(来源:ARITHMETIC CODING FOR DATA COIUPRESSION): ...

随机推荐

  1. [2019牛客多校第二场][G. Polygons]

    题目链接:https://ac.nowcoder.com/acm/contest/882/G 题目大意:有\(n\)条直线将平面分成若干个区域,要求处理\(m\)次询问:求第\(q\)大的区域面积.保 ...

  2. Codeforces Round #590 (Div. 3) E. Special Permutations

    链接: https://codeforces.com/contest/1234/problem/E 题意: Let's define pi(n) as the following permutatio ...

  3. 一文学会redis从零到入门

    本文参照视屏学习整理:https://www.bilibili.com/video/av16841549/?p=9 相关软件.资料: 基本条件:有虚拟机或相关linux系统,熟悉基本linux操作 本 ...

  4. clone([Even[,deepEven]])克隆匹配的DOM元素并且选中这些克隆的副本。

    clone([Even[,deepEven]]) 概述 克隆匹配的DOM元素并且选中这些克隆的副本. 在想把DOM文档中元素的副本添加到其他位置时这个函数非常有用. 参数 EventsBooleanV ...

  5. PHP mysqli_fetch_assoc() 函数

    从结果集中取得一行作为关联数组: <?php // 假定数据库用户名:root,密码:123456,数据库:RUNOOB $con=mysqli_connect("localhost& ...

  6. HDU 5734 Acperience ( 数学公式推导、一元二次方程 )

    题目链接 题意 : 给出 n 维向量 W.要你构造一个 n 维向量 B = ( b1.b2.b3 ..... ) ( bi ∈ { +1, -1 } ) .然后求出对于一个常数 α > 0 使得 ...

  7. shell基础之一

    Shell脚本自动化管理系统的必备基础: vim编辑器的熟练使用,SSH终端及“.vimrc”的设置等等需要熟悉. 命令基础:Linux的150个常用命令的熟练使用 Linux的正则表达式以及三剑客( ...

  8. python3版 爬虫了解

    摘要:本文将使用Python3.4爬网页.爬图片.自动登录.并对HTTP协议做了一个简单的介绍.在进行爬虫之前,先简单来进行一个HTTP协议的讲解,这样下面再来进行爬虫就是理解更加清楚. 一.HTTP ...

  9. 微信小程序之scroll-view的坑

    好久没动小程序了,今天打算复习复习,结果刚写了一个scroll-view就遇到了一个坑,这怎么能忍,对比看文档也没发现那里出了问题,没办法只能去翻翻微信给的demo,发现scroll-view一个必要 ...

  10. Nginx之搭建反向代理实现tomcat分布式集群

    参考博文: Nginx反向代理实现Tomcat分布式集群 1. jdk 安装 jdk 下载网址: http://www.oracle.com/technetwork/java/javase/downl ...