Tesorflow-自动编码器(AutoEncoder)
直接附上代码:
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data def xavier_init(fan_in,fan_out,constant=1):
low=-constant*np.sqrt(6.0/(fan_in+fan_out))
high=constant*np.sqrt(6.0/(fan_in+fan_out))
return tf.random_uniform((fan_in,fan_out),minval=low,maxval=high,dtype=tf.float32) class AdditiveGaussianNoiseAutoencoder(object):
def __init__(self,n_input,n_hidden,transfer_function=tf.nn.softplus,optimizer=tf.train.AdamOptimizer(),scale=0.1):
self.n_input=n_input
self.n_hidden=n_hidden
self.transfer=transfer_function
self.scale=tf.placeholder(tf.float32)
self.training_scale=scale
network_weights=self._initialize_weights()
self.weights=network_weights self.x=tf.placeholder(tf.float32,[None,self.n_input])
self.hidden=self.transfer(tf.add(tf.matmul(self.x+scale*tf.random_normal((n_input,)),self.weights['w1']),self.weights['b1']))
self.reconstruction=tf.add(tf.matmul(self.hidden,self.weights['w2']),self.weights['b2'])
self.cost=0.5*tf.reduce_sum(tf.pow(tf.sub(self.reconstruction,self.x),2.0))
self.optimizer=optimizer.minimize(self.cost) init=tf.initialize_all_variables()
self.sess=tf.Session()
self.sess.run(init) def _initialize_weights(self):
all_weights=dict()
all_weights['w1']=tf.Variable(xavier_init(self.n_input,self.n_hidden))
all_weights['b1']=tf.Variable(tf.zeros([self.n_hidden],dtype=tf.float32))
all_weights['w2']=tf.Variable(tf.zeros([self.n_hidden,self.n_input],dtype=tf.float32))
all_weights['b2']=tf.Variable(tf.zeros([self.n_input],dtype=tf.float32)) return all_weights def partial_fit(self,X): cost,opt=self.sess.run((self.cost,self.optimizer),feed_dict={self.x:X,self.scale:self.training_scale}) return cost def calc_total_cost(self,X):
return self.sess.run(self.cost,feed_dict={self.x:X,self.scale:self.training_scale}) def transform(self,X):
return self.sess.run(self.hidden,feed_dict={self.x:X,self.scale:self.training_scale}) def generate(self,hidden=None):
if hidden is None:
hidden=np.random.normal(size=self.weights["b1"])
return self.sess.run(self.reconstruction,feed_dict={self.hidden:hidden}) def reconstruct(self,X):
return self.sess.run(self.reconstruction,feed_dict={self.x:X,self.scale:self.training_scale}) def getWeights(self):
return self.sess.run(self.weights['w1']) def getBiases(self):
return self.sess.run(self.weights['b1']) mnist=input_data.read_data_sets('MNIST_data',one_hot=True) def standard_scale(X_train,X_test):
preprocessor=prep.StandardScaler().fit(X_train)
X_train=preprocessor.transform(X_train)
X_test=preprocessor.transform(X_test)
return X_train,X_test def get_random_block_from_data(data,batch_size):
start_index=np.random.randint(0,len(data)-batch_size)
return data[start_index:(start_index+batch_size)] X_train,X_test=standard_scale(mnist.train.images,mnist.test.images)
n_samples=int(mnist.train.num_examples)
training_epochs=20
batch_size=128
diaplay_step=1
autoencoder=AdditiveGaussianNoiseAutoencoder(n_input=784,n_hidden=200,transfer_function=tf.nn.softplus,optimizer=tf.train.AdamOptimizer(learning_rate=0.001),scale=0.01)
for epoch in range(training_epochs):
avg_cost=0
total_batch=int(n_samples/batch_size)
for i in range(total_batch):
batch_xs=get_random_block_from_data(X_train,batch_size) cost=autoencoder.partial_fit(batch_xs)
avg_cost+=cost/n_samples*batch_size if epoch%diaplay_step==0:
print("Epoch:",'%04d'%(epoch+1),"cost=","{:.9f}".format(avg_cost)) print("Total cost: "+str(autoencoder.calc_total_cost(X_test)))
Tesorflow-自动编码器(AutoEncoder)的更多相关文章
- VAE--就是AutoEncoder的编码输出服从正态分布
花式解释AutoEncoder与VAE 什么是自动编码器 自动编码器(AutoEncoder)最开始作为一种数据的压缩方法,其特点有: 1)跟数据相关程度很高,这意味着自动编码器只能压缩与训练数据相似 ...
- Machine Learning Algorithms Study Notes(6)—遗忘的数学知识
机器学习中遗忘的数学知识 最大似然估计( Maximum likelihood ) 最大似然估计,也称为最大概似估计,是一种统计方法,它用来求一个样本集的相关概率密度函数的参数.这个方法最早是遗传学家 ...
- Machine Learning Algorithms Study Notes(4)—无监督学习(unsupervised learning)
1 Unsupervised Learning 1.1 k-means clustering algorithm 1.1.1 算法思想 1.1.2 k-means的不足之处 1 ...
- Machine Learning Algorithms Study Notes(1)--Introduction
Machine Learning Algorithms Study Notes 高雪松 @雪松Cedro Microsoft MVP 目 录 1 Introduction 1 1.1 ...
- 【Todo】【转载】深度学习&神经网络 科普及八卦 学习笔记 & GPU & SIMD
上一篇文章提到了数据挖掘.机器学习.深度学习的区别:http://www.cnblogs.com/charlesblc/p/6159355.html 深度学习具体的内容可以看这里: 参考了这篇文章:h ...
- (zhuan) 深度学习全网最全学习资料汇总之模型介绍篇
This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058& ...
- SIGAI深度学习第四集 深度学习简介
讲授机器学习面临的挑战.人工特征的局限性.为什么选择神经网络.深度学习的诞生和发展.典型的网络结构.深度学习在机器视觉.语音识别.自然语言处理.推荐系统中的应用 大纲: 机器学习面临的挑战 特征工程的 ...
- AI面试必备/深度学习100问1-50题答案解析
AI面试必备/深度学习100问1-50题答案解析 2018年09月04日 15:42:07 刀客123 阅读数 2020更多 分类专栏: 机器学习 转载:https://blog.csdn.net ...
- 栈式自动编码器(Stacked AutoEncoder)
起源:自动编码器 单自动编码器,充其量也就是个强化补丁版PCA,只用一次好不过瘾. 于是Bengio等人在2007年的 Greedy Layer-Wise Training of Deep Netw ...
- 降噪自动编码器(Denoising Autoencoder)
起源:PCA.特征提取.... 随着一些奇怪的高维数据出现,比如图像.语音,传统的统计学-机器学习方法遇到了前所未有的挑战. 数据维度过高,数据单调,噪声分布广,传统方法的“数值游戏”很难奏效.数据挖 ...
随机推荐
- DataTable与结构不同实体类之间的转换
在实际开发过程中,或者是第三方公司提供的数据表结构,与我们系统中的实体类字段不对应,遇到这样我们怎么处理呢?可能有人会说,在转换时创建一个实体对象,对表里的数据逐行遍历来实例化这个实体对象不就完了.的 ...
- Maven类包冲突终极三大解决技巧 mvn dependency:tree
Maven对于新手来说是<步步惊心>,因为它包罗万象,博大精深,因为当你初来乍到时,你就像一个进入森林的陌生访客一样迷茫. Maven对于老手来说是<真爱配方>,因为它无所不能 ...
- Ext JS v2.3.0 Ext.grid.ColumnModel renderer Record 获取列值
场景:设置某一列的值,但是需要获取其他列的值 {"header": '<s:property value="name" />', "wid ...
- [database] postgresql 外网访问
配置 环境 ubuntu 14.04 LTS 版本 postgresql version 9.3 修改监听地址 编辑 /etc/postgresql/9.3/main/postgresql.conf ...
- php CI框架 使用PDO 的连接配置
$db['default'] = array( 'dsn' => 'mysql:dbname=hejun;host=192.168.137.127', //'hostname' => '' ...
- CHNetRequest网络请求
Paste JSON as Code • quicktype 软件的使用 iOS开发:官方自带的JSON使用 JSON 数据解析 XML 数据解析 Plist 数据解析 NetRequest 网络数据 ...
- 自定义Cell的流程
1..h文件 // // 文 件 名:CHBackupGateWayCell.h // // 版权所有:Copyright © 2018 lelight. All rights reserved. / ...
- servlet-mysql实现简单用户登录注册
环境:IDEA Maven 效果截图: 项目结构: 类说明: ConnectionUtil:负责数据库连接和释放 UserDao:数据库增删改查操作 User:用户Bean,只用注册和登录的话可以不要 ...
- memcached装、启动和卸载
1.下载相关软件: 下载地址:http://download.csdn.net/download/wangshuxuncom/8249501: 2.解压获取到的压缩文件,将得到一个名为“memcach ...
- 再谈VS2010编译更高平台vs2012(v110),vs2015(v140)的objectARX程序
前段时间我贴了一篇vs2010批量编译vc6~vs2008的ARX版本,实际上那一篇是我在研究vs2010编译v110,v140平台的附带收获,正应了那句话,有心栽花花不开,无心插柳柳成荫,因为vs2 ...