Auto-Encoders实战
Outline
Auto-Encoder
Variational Auto-Encoders
Auto-Encoder
创建编解码器
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
def save_images(imgs, name):
new_im = Image.new('L', (280, 280))
index = 0
for i in range(0, 280, 28):
for j in range(0, 280, 28):
im = imgs[index]
im = Image.fromarray(im, mode='L')
new_im.paste(im, (i, j))
index += 1
new_im.save(name)
h_dim = 20 # 784降维20维
batchsz = 512
lr = 1e-3
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(
np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
class AE(keras.Model):
def __init__(self):
super(AE, self).__init__()
# Encoders
self.encoder = Sequential([
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(h_dim)
])
# Decoders
self.decoder = Sequential([
layers.Dense(128, activation=tf.nn.relu),
layers.Dense(256, activation=tf.nn.relu),
layers.Dense(784)
])
def call(self, inputs, training=None):
# [b,784] ==> [b,19]
h = self.encoder(inputs)
# [b,10] ==> [b,784]
x_hat = self.decoder(h)
return x_hat
model = AE()
model.build(input_shape=(None, 784)) # tensorflow尽量用元组
model.summary()
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) multiple 236436
_________________________________________________________________
sequential_1 (Sequential) multiple 237200
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
训练
optimizer = tf.optimizers.Adam(lr=lr)
for epoch in range(10):
for step, x in enumerate(train_db):
# [b,28,28]==>[b,784]
x = tf.reshape(x, [-1, 784])
with tf.GradientTape() as tape:
x_rec_logits = model(x)
rec_loss = tf.losses.binary_crossentropy(x,
x_rec_logits,
from_logits=True)
rec_loss = tf.reduce_min(rec_loss)
grads = tape.gradient(rec_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
if step % 100 == 0:
print(epoch, step, float(rec_loss))
# evaluation
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))
x_hat = tf.sigmoid(logits)
# [b,784]==>[b,28,28]
x_hat = tf.reshape(x_hat, [-1, 28, 28])
# [b,28,28] ==> [2b,28,28]
x_concat = tf.concat([x, x_hat], axis=0)
# x_concat = x # 原始图片
x_concat = x_hat
x_concat = x_concat.numpy() * 255.
x_concat = x_concat.astype(np.uint8) # 保存为整型
if not os.path.exists('ae_images'):
os.mkdir('ae_images')
save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703
Auto-Encoders实战的更多相关文章
- [Python] 机器学习库资料汇总
声明:以下内容转载自平行宇宙. Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: ...
- python数据挖掘领域工具包
原文:http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块:Numpy和Sc ...
- Theano3.1-练习之初步介绍
来自 http://deeplearning.net/tutorial/,虽然比较老了,不过觉得想系统的学习theano,所以需要从python--numpy--theano的顺序学习.这里的资料都很 ...
- [resource]Python机器学习库
reference: http://qxde01.blog.163.com/blog/static/67335744201368101922991/ Python在科学计算领域,有两个重要的扩展模块: ...
- 机器学习——深度学习(Deep Learning)
Deep Learning是机器学习中一个非常接近AI的领域,其动机在于建立.模拟人脑进行分析学习的神经网络,近期研究了机器学习中一些深度学习的相关知识,本文给出一些非常实用的资料和心得. Key W ...
- Deep Learning Tutorial - Classifying MNIST digits using Logistic Regression
Deep Learning Tutorial 由 Montreal大学的LISA实验室所作,基于Theano的深度学习材料.Theano是一个python库,使得写深度模型更容易些,也可以在GPU上训 ...
- [转]Python机器学习工具箱
原文在这里 Python在科学计算领域,有两个重要的扩展模块:Numpy和Scipy.其中Numpy是一个用python实现的科学计算包.包括: 一个强大的N维数组对象Array: 比较成熟的(广播 ...
- 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks
In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...
- Deep Learning(4)
四.拓展学习推荐 Deep Learning 经典阅读材料: The monograph or review paper Learning Deep Architectures for AI (Fou ...
- 深度学习教程Deep Learning Tutorials
Deep Learning Tutorials Deep Learning is a new area of Machine Learning research, which has been int ...
随机推荐
- bzoj 3830: [Poi2014]Freight【dp】
参考:https://blog.csdn.net/zqh_wz/article/details/52953516 妙啊 看成分段问题,因为火车只能一批一批的走(易证= =)设f[i]为到i为止的车都走 ...
- 洛谷 P3356 火星探险问题 【最大费用最大流】
输出方案好麻烦啊 拆点,石头的连(i,i',1,1)(i,i',inf,0)表示可以取一次价值1,空地直接连(i,i',inf,0),对于能走到的两个格子(不包括障碍),连接(i',j,inf,0), ...
- 洛谷 P2770 航空路线问题【最大费用最大流】
记得cnt=1!!因为是无向图所以可以把回来的路看成另一条向东的路.字符串用map处理即可.拆点限制流量,除了1和n是(i,i+n,2)表示可以经过两次,其他点都拆成(i,i+n,1),费用设为1,原 ...
- TRACE Method 网站漏洞,你关闭了吗[转]
危险:该漏洞可能篡改网页HTML 源码 最近采用360 web scan 对服务器进行扫描.发现漏洞.TRACE Method Enabled 安全打分98分.前一阵有网页JS被人篡改,可能就是从这个 ...
- Monkey Banana Problem LightOJ - 1004
Monkey Banana Problem LightOJ - 1004 错误记录: 1.数组开小2.每组数据数组没有清空 #include<cstdio> #include<cst ...
- XmlPullParser接口详述
带*的是非常重要的函数.点击有说明.setInputgetDepthisWhitespacegetTextisEmptyElementTaggetAttributeCountgetAttributeV ...
- Oracle报错:“ORA-18008: 无法找到 OUTLN 方案 ”的解决方案
Oracle报错:“ORA-18008: 无法找到 OUTLN 方案 ”的解决方案 2.修改replication_dependency_tracking参数 SQL> alter syst ...
- 211 Add and Search Word - Data structure design 添加与搜索单词 - 数据结构设计
设计一个支持以下两个操作的数据结构:void addWord(word)bool search(word)search(word) 可以搜索文字或正则表达式字符串,字符串只包含字母 . 或 a-z . ...
- android开发学习 ------- volley网络请求的实例
在 http://www.sojson.com/httpRequest/ 上对http进行访问,将此访问在android中的应用 ********************************* ...
- .NET 几种数据绑定控件的区别
GridView 控件 GridView 控件以表的形式显示数据,并提供对列进行排序.分页.翻阅数据以及编辑或删除单个记录的功能. 特征:一行一条记录,就像新闻列表一样:带分页功能. DataList ...