自编码器是一种数据压缩算法,其中数据的压缩和解压缩函数是数据相关的、从样本中训练而来的。大部分自编码器中,压缩和解压缩的函数是通过神经网络实现的。

1. 使用卷积神经网络搭建自编码器

  • 导入MNIST数据集(灰度图,像素范围0~1)

    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('MNIST_data', validation_size=0)
  • 搭建网络
      inputs_ = tf.placeholder(tf.float32, (None, 28, 28, 1), name='inputs')
    targets_ = tf.placeholder(tf.float32, (None, 28, 28, 1), name='targets')
    ### Encoder
    conv1 = tf.layers.conv2d(inputs_, 16, (3,3), padding='same', activation=tf.nn.relu) # 28x28x16
    maxpool1 = tf.layers.max_pooling2d(conv1, (2,2), (2,2), padding='same') # 14x14x16
    conv2 = tf.layers.conv2d(maxpool1, 8, (3,3), padding='same', activation=tf.nn.relu) # 14x14x8
    maxpool2 = tf.layers.max_pooling2d(conv2, (2,2), (2,2), padding='same') # 7x7x8
    conv3 = tf.layers.conv2d(maxpool2, 8, (3,3), padding='same', activation=tf.nn.relu) # 7x7x8
    encoded = tf.layers.max_pooling2d(conv3, (2,2), (2,2), padding='same') # 4x4x8
    ### Decoder
    upsample1 = tf.image.resize_nearest_neighbor(encoded, (7,7)) # 7x7x8
    conv4 = tf.layers.conv2d(upsample1, 8, (3,3), padding='same', activation=tf.nn.relu) # 7x7x8
    upsample2 = tf.image.resize_nearest_neighbor(conv4, (14,14)) # 14x14x8
    conv5 = tf.layers.conv2d(upsample2, 8, (3,3), padding='same', activation=tf.nn.relu) # 14x14x8
    upsample3 = tf.image.resize_nearest_neighbor(conv5, (28,28)) # 28x28x8
    conv6 = tf.layers.conv2d(upsample3, 16, (3,3), padding='same', activation=tf.nn.relu) # 28x28x16
    logits = tf.layers.conv2d(conv6, 1, (3,3), padding='same', activation=None) # 28x28x1
    decoded = tf.nn.sigmoid(logits, name='decoded') # 28x28x1
    ### Loss and Optimization:
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets_, logits=logits)
    cost = tf.reduce_mean(loss)
    opt = tf.train.AdamOptimizer(0.001).minimize(cost)

    模型在解码部分使用的是upsample+convolution而不是transposed convolution(参考文献

  • 训练网络
      sess = tf.Session()
    epochs = 20
    batch_size = 200
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
    for ii in range(mnist.train.num_examples//batch_size):
    batch = mnist.train.next_batch(batch_size)
    imgs = batch[0].reshape((-1, 28, 28, 1))
    batch_cost, _ = sess.run([cost, opt], feed_dict={inputs_: imgs, targets_: imgs})
    print("Epoch: {}/{}...".format(e+1, epochs), "Training loss: {:.4f}".format(batch_cost))
  • 检验网络
      fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20,4))
    in_imgs = mnist.test.images[:10]
    reconstructed, compressed = sess.run([decoded, encoded], feed_dict={inputs_: in_imgs.reshape((10, 28, 28, 1))})
    # plot
    for images, row in zip([in_imgs, reconstructed], axes):
    for img, ax in zip(images, row):
    ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0.1)
    sess.close()

2. 使用自编码器降噪

  • 搭建网络(同上但feature map的个数由16-8-8-8-8-16变为32-32-16-16-32-32)
  • 训练网络
      sess = tf.Session()
    epochs = 100
    batch_size = 200
    # Set's how much noise we're adding to the MNIST images
    noise_factor = 0.5
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
    for ii in range(mnist.train.num_examples//batch_size):
    batch = mnist.train.next_batch(batch_size)
    # Get images from the batch
    imgs = batch[0].reshape((-1, 28, 28, 1))
    # Add random noise to the input images
    noisy_imgs = imgs + noise_factor * np.random.randn(*imgs.shape)
    # Clip the images to be between 0 and 1
    noisy_imgs = np.clip(noisy_imgs, 0., 1.)
    # Noisy images as inputs, original images as targets
    batch_cost, _ = sess.run([cost, opt], feed_dict={inputs_: noisy_imgs, targets_: imgs})
    print("Epoch: {}/{}...".format(e+1, epochs), "Training loss: {:.4f}".format(batch_cost))
  • 检验网络
      fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20,4))
    in_imgs = mnist.test.images[:10]
    noisy_imgs = in_imgs + noise_factor * np.random.randn(*in_imgs.shape)
    noisy_imgs = np.clip(noisy_imgs, 0., 1.)
    reconstructed = sess.run(decoded, feed_dict={inputs_: noisy_imgs.reshape((10, 28, 28, 1))})
    for images, row in zip([noisy_imgs, reconstructed], axes):
    for img, ax in zip(images, row):
    ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0.1)
    sess.close()

使用Tensorflow搭建自编码器(Autoencoder)的更多相关文章

  1. TensorFlow实现自编码器及多层感知机

    1 自动编码机简介        传统机器学习任务在很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难 ...

  2. (转)一文学会用 Tensorflow 搭建神经网络

    一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...

  3. [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)

    3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...

  4. 用Tensorflow搭建神经网络的一般步骤

    用Tensorflow搭建神经网络的一般步骤如下: ① 导入模块 ② 创建模型变量和占位符 ③ 建立模型 ④ 定义loss函数 ⑤ 定义优化器(optimizer), 使 loss 达到最小 ⑥ 引入 ...

  5. 一文学会用 Tensorflow 搭建神经网络

    http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...

  6. 使用Tensorflow搭建回归预测模型之一:环境安装

    方法1:快速包安装 一.安装Anaconda 1.官网地址:https://www.anaconda.com/distribution/,选择其中一个版本下载即可,最好安装3.7版本,因为2.7版本2 ...

  7. 使用Tensorflow搭建回归预测模型之二:数据准备与预处理

    前言: 在前一篇中,已经搭建好了Tensorflow环境,本文将介绍如何准备数据与预处理数据. 正文: 在机器学习中,数据是非常关键的一个环节,在模型训练前对数据进行准备也预处理是非常必要的. 一.数 ...

  8. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  9. 用TensorFlow搭建一个万能的神经网络框架(持续更新)

    我一直觉得TensorFlow的深度神经网络代码非常困难且繁琐,对TensorFlow搭建模型也十分困惑,所以我近期阅读了大量的神经网络代码,终于找到了搭建神经网络的规律,各位要是觉得我的文章对你有帮 ...

随机推荐

  1. Go Pentester - HTTP Servers(2)

    Routing with the gorilla/mux Package A powerful HTTP router and URL matcher for building Go web serv ...

  2. fastjson将json字符串转化为java对象

    目录 一.导入一个fastjson的jar包 二.Json字符串格式 三.根据json的格式创建Java类 四.给java类的所有属性添加setter方法 五.转换为java对象 一.导入一个fast ...

  3. UVa 548 Tree(中序遍历+后序遍历)

    给一棵点带权(权值各不相同,都是小于10000的正整数)的二叉树的中序和后序遍历,找一个叶子使得它到根的路径上的权和最小.如果有多解,该叶子本身的权应尽量小.输入中每两行表示一棵树,其中第一行为中序遍 ...

  4. 一起学Blazor WebAssembly 开发(3)

    接着上篇,本篇开始讲下实现登录窗口,先看下大概的效果图: 打开的效果,没有美工美化 点登录校验得到不能为空 我在做blazor时用到了一个ui框架,这个框架名叫Ant Design blazor(ht ...

  5. 题解 洛谷 P3185 【[HNOI2007]分裂游戏】

    首先可以发现,当所有巧克力豆在最后一个瓶子中时,就无法再操作了,此时为必败状态. 注意到,对于每个瓶子里的巧克力豆,是可以在模\(2\)的意义下去考虑的,因为后手可以模仿先手的操作,所以就将巧克力豆个 ...

  6. springboot(三)SpringDataJPA完成CRUD

    参考博客—恒宇少年:https://www.jianshu.com/p/b6932740f3c0 纯洁的微笑:http://www.ityouknow.com/springboot/2016/08/2 ...

  7. 豆瓣 9.0 分的《Python学习知识手册》|百度网盘免费下载|

    豆瓣 9.0 分的<Python学习知识手册>|百度网盘免费下载| 提取码:nuak 这是之前入门学习Python时候的学习资料,非常全面,从Python基础.到web开发.数据分析.机器 ...

  8. .NET Core学习笔记(7)——Exception最佳实践

    1.为什么不要给每个方法都写try catch 为每个方法都编写try catch是错误的做法,理由如下: a.重复嵌套的try catch是无用的,多余的. 这一点非常容易理解,下面的示例代码中,O ...

  9. Kafka和SpringBoot

    事先必备: kafka已安装完成 1.目录结构 2.父pom <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns ...

  10. Python 字典(Dictionary) clear()方法

    Python 字典(Dictionary) clear()方法 描述 Python 字典(Dictionary) clear() 函数用于删除字典内所有元素.高佣联盟 www.cgewang.com ...