Auto-Encoders实战
Outline
Auto-Encoder
Variational Auto-Encoders
Auto-Encoder

创建编解码器
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
def save_images(imgs, name):
new_im = Image.new('L', (280, 280))
index = 0
for i in range(0, 280, 28):
for j in range(0, 280, 28):
im = imgs[index]
im = Image.fromarray(im, mode='L')
new_im.paste(im, (i, j))
index += 1
new_im.save(name)
h_dim = 20 # 784降维20维
batchsz = 512
lr = 1e-3
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(
np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
class AE(keras.Model):
def __init__(self):
super(AE, self).__init__()
# Encoders
self.encoder = Sequential([
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(h_dim)
])
# Decoders
self.decoder = Sequential([
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(784)
])
def call(self, inputs, training=None):
# [b,784] ==> [b,19]
h = self.encoder(inputs)
# [b,10] ==> [b,784]
x_hat = self.decoder(h)
return x_hat
model = AE()
model.build(input_shape=(None, 784)) # tensorflow尽量用元组
model.summary()
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
训练
optimizer = tf.optimizers.Adam(lr=lr)
for epoch in range(10):
for step, x in enumerate(train_db):
# [b,28,28]==>[b,784]
x = tf.reshape(x, [-1, 784])
with tf.GradientTape() as tape:
x_rec_logits = model(x)
rec_loss = tf.losses.binary_crossentropy(x,
x_rec_logits,
from_logits=True)
rec_loss = tf.reduce_min(rec_loss)
grads = tape.gradient(rec_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
if step % 100 == 0:
print(epoch, step, float(rec_loss))
# evaluation
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
x_hat = tf.sigmoid(logits)
# [b,784]==>[b,28,28]
x_hat = tf.reshape(x_hat, [-1, 28, 28])
# [b,28,28] ==> [2b,28,28]
x_concat = tf.concat([x, x_hat], axis=0)
# x_concat = x # 原始图片
x_concat = x_hat
x_concat = x_concat.numpy() * 255.
x_concat = x_concat.astype(np.uint8) # 保存为整型
if not os.path.exists('ae_images'):
os.mkdir('ae_images')
save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703
Auto-Encoders实战的更多相关文章
- [Python] 机器学习库资料汇总
声明:以下内容转载自平行宇宙. Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: ...
- python数据挖掘领域工具包
原文:http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块:Numpy和Sc ...
- Theano3.1-练习之初步介绍
来自 http://deeplearning.net/tutorial/,虽然比较老了,不过觉得想系统的学习theano,所以需要从python--numpy--theano的顺序学习.这里的资料都很 ...
- [resource]Python机器学习库
reference: http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块: ...
- 机器学习——深度学习(Deep Learning)
Deep Learning是机器学习中一个非常接近AI的领域,其动机在于建立.模拟人脑进行分析学习的神经网络,近期研究了机器学习中一些深度学习的相关知识,本文给出一些非常实用的资料和心得. Key W ...
- Deep Learning Tutorial - Classifying MNIST digits using Logistic Regression
Deep Learning Tutorial 由 Montreal大学的LISA实验室所作,基于Theano的深度学习材料.Theano是一个python库,使得写深度模型更容易些,也可以在GPU上训 ...
- [转]Python机器学习工具箱
原文在这里 Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: 比较成熟的(广播 ...
- 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks
In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...
- Deep Learning(4)
四.拓展学习推荐 Deep Learning 经典阅读材料: The monograph or review paper Learning Deep Architectures for AI (Fou ...
- 深度学习教程Deep Learning Tutorials
Deep Learning Tutorials Deep Learning is a new area of Machine Learning research, which has been int ...
随机推荐
- 压力测试之jmeter使用
我很早之前就会使用jmeter,一直以为压力测试很简单,知道真正去做才明白,真正的压力测试并不只是会用jmeter而已.我现在才明白:会工具并不等同于会压力测试.对于压力测试需要补充的知识还有很多.. ...
- docker学习教程
我们的docker学习教程可以分为以下几个部分,分别是: 第一:docker基础学习 第二:docker日志管理 第三:docker监控管理 第四:docker三剑客之一:docker-machine ...
- libtool版本过新的问题
安装过程中出现: libtool: Version mismatch error. This is libtool 2.4.2, but the libtool: definition of th ...
- redis的安装使用以及一些常用的命令
Redis是一个key-value存储系统.并提供多种语言的API,我们可使用它构建高性能,可扩展的Web应用程序.目前越来越多的网站用它来当做缓存,减轻服务器的压力. 本文安装用的到redis是绿色 ...
- 关于OPPO手机的生存和程序员的发展
关于程序员私下讨论最多的话题,除了哪个编程最牛逼之外,哪款品牌的手机最牛逼也是我们谈论最多的话题之一吧!有的喜欢罗永浩,自然就是锤粉:有的喜欢苹果,称它为工业时代最优美的艺术品:当然,我想也有很多的人 ...
- 掌握Spark机器学习库-01
第1章 初识机器学习 在本章中将带领大家概要了解什么是机器学习.机器学习在当前有哪些典型应用.机器学习的核心思想.常用的框架有哪些,该如何进行选型等相关问题. 1-1 导学 1-2 机器学习概述 1- ...
- 关于mapState和mapMutations和mapGetters 和mapActions辅助函数的用法及作用(三)-----mapGetters
简单的理解: const getters = { newCount: function(state){ return state.count += 5; } } ------------------- ...
- jQuery核心语法
.each只是处理jQuery对象的方法,jQuery还提供了一个通用的jQuery.each方法,用来处理对象和数组的遍历 jQuery/($).each(array, callback )jQue ...
- python_MachineLearning_感知机PLA
感知机:线性二类分类器(linear binary classifier) 感知机(perceptron)是二类分类的线性模型,其输入为实例的特征向量,输出为实例的类别,取+1和-1二值.感知机对 ...
- swift 使用计算属性+结构管理内存
class GooClass { deinit { print("aaaaaaaa") } var str = "gooClass" } struct GooS ...