Auto-Encoders实战
Outline
Auto-Encoder
Variational Auto-Encoders
Auto-Encoder
创建编解码器
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
def save_images(imgs, name):
new_im = Image.new('L', (280, 280))
index = 0
for i in range(0, 280, 28):
for j in range(0, 280, 28):
im = imgs[index]
im = Image.fromarray(im, mode='L')
new_im.paste(im, (i, j))
index += 1
new_im.save(name)
h_dim = 20 # 784降维20维
batchsz = 512
lr = 1e-3
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(
np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
class AE(keras.Model):
def __init__(self):
super(AE, self).__init__()
# Encoders
self.encoder = Sequential([
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(h_dim)
])
# Decoders
self.decoder = Sequential([
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(784)
])
def call(self, inputs, training=None):
# [b,784] ==> [b,19]
h = self.encoder(inputs)
# [b,10] ==> [b,784]
x_hat = self.decoder(h)
return x_hat
model = AE()
model.build(input_shape=(None, 784)) # tensorflow尽量用元组
model.summary()
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
训练
optimizer = tf.optimizers.Adam(lr=lr)
for epoch in range(10):
for step, x in enumerate(train_db):
# [b,28,28]==>[b,784]
x = tf.reshape(x, [-1, 784])
with tf.GradientTape() as tape:
x_rec_logits = model(x)
rec_loss = tf.losses.binary_crossentropy(x,
x_rec_logits,
from_logits=True)
rec_loss = tf.reduce_min(rec_loss)
grads = tape.gradient(rec_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
if step % 100 == 0:
print(epoch, step, float(rec_loss))
# evaluation
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
x_hat = tf.sigmoid(logits)
# [b,784]==>[b,28,28]
x_hat = tf.reshape(x_hat, [-1, 28, 28])
# [b,28,28] ==> [2b,28,28]
x_concat = tf.concat([x, x_hat], axis=0)
# x_concat = x # 原始图片
x_concat = x_hat
x_concat = x_concat.numpy() * 255.
x_concat = x_concat.astype(np.uint8) # 保存为整型
if not os.path.exists('ae_images'):
os.mkdir('ae_images')
save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703
Auto-Encoders实战的更多相关文章
- [Python] 机器学习库资料汇总
声明:以下内容转载自平行宇宙. Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: ...
- python数据挖掘领域工具包
原文:http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块:Numpy和Sc ...
- Theano3.1-练习之初步介绍
来自 http://deeplearning.net/tutorial/,虽然比较老了,不过觉得想系统的学习theano,所以需要从python--numpy--theano的顺序学习.这里的资料都很 ...
- [resource]Python机器学习库
reference: http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块: ...
- 机器学习——深度学习(Deep Learning)
Deep Learning是机器学习中一个非常接近AI的领域,其动机在于建立.模拟人脑进行分析学习的神经网络,近期研究了机器学习中一些深度学习的相关知识,本文给出一些非常实用的资料和心得. Key W ...
- Deep Learning Tutorial - Classifying MNIST digits using Logistic Regression
Deep Learning Tutorial 由 Montreal大学的LISA实验室所作,基于Theano的深度学习材料.Theano是一个python库,使得写深度模型更容易些,也可以在GPU上训 ...
- [转]Python机器学习工具箱
原文在这里 Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: 比较成熟的(广播 ...
- 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks
In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...
- Deep Learning(4)
四.拓展学习推荐 Deep Learning 经典阅读材料: The monograph or review paper Learning Deep Architectures for AI (Fou ...
- 深度学习教程Deep Learning Tutorials
Deep Learning Tutorials Deep Learning is a new area of Machine Learning research, which has been int ...
随机推荐
- NOIp 2014 联合权值 By cellur925
题目传送门 这题自己(真正)思考了很久(欣慰). (轻而易举)地发现这是一棵树后,打算从Dfs序中下功夫,推敲了很久规律,没看出来(太弱了). 开始手动枚举距离为2的情况,模模糊糊有了一些概念,但没有 ...
- eurekaclient向eurekaserver注册使用真实ip设置
有时候eureka.instance.prefer-ip-address=true不管用,解决办法如下.
- poj 3281 Dining (最大网络流)
题目链接: http://poj.org/problem?id=3281 题目大意: 有n头牛,f种食物,d种饮料,第i头牛喜欢fi种食物和di种饮料,每种食物或者饮料被一头牛选中后,就不能被其他的牛 ...
- [USACO 2012 Open Gold] Bookshelf【优化dp】
传送门1:http://www.usaco.org/index.php?page=viewproblem2&cpid=138 传送门2:http://www.lydsy.com/JudgeOn ...
- PHP 官方说明
http://php.net/manual/en/mysqli.affected-rows.php The above examples will output: Affected rows (INS ...
- [BZOJ1878][SDOI2009]HH的项链 莫队
题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=1878 不带修改的莫队,用一个桶记录一下当前区间中每种颜色的数量就可以做到$O(1)$更新了 ...
- Android EditText 输入金额(小数点后两位)
EditText edit = new EditText(context); InputType.TYPE_NUMBER_FLAG_DECIMAL //小数点型 InputType.TYPE_CLAS ...
- js中cookie的操作
JavaScript中的另一个机制:cookie,则可以达到真正全局变量的要求. cookie是浏览器 提供的一种机制,它将document 对象的cookie属性提供给JavaScript.可以由J ...
- thinkphp5 404 file_put_contents 无法打开流:权限被拒绝
如果你用TP的时间比较长,或者说你比较了解TP的人都会知道,TP的runtime它需要的权限是很大的,如果你只给一般权限肯定是不行的,通常都是给runtime权限:777: linux命令如下: cd ...
- iTOP-4412开发板-LinuxC-继电器模块的测试例程
平台:iTOP-4412开发板 实现:继电器模块测试例程 继电器的 C 的测试程序,C 测试程序可以在 Android系统,Qt 系统以及最小 linux 系统上运行,文档以 Android 系统上测 ...