tensorflow MNIST Convolutional Neural Network

MNIST CNN 包含的几个部分:

  • Weight Initialization
  • Convolution and Pooling
  • Convolution layer
  • Fully connected layer
  • Readout Layer

直接上tensorflow 给的示例:

先读入数据:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
import time
#Weight Initialization
#现在还没有值,只是计算图中的节点,直到‘tf.global_variables_initializer()’才初始化
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
#Convolution and Pooling
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)

  • 第一个参数input:指需要做卷积的输入图像,它要求是一个Tensor,具有[batch, in_height, in_width, in_channels]这样的shape,具体含义是[训练时一个batch的图片数量, 图片高度, 图片宽度, 图像通道数],注意这是一个4维的Tensor,要求类型为float32和float64其中之一
  • 第二个参数filter:相当于CNN中的卷积核,它要求是一个Tensor,具有[filter_height, filter_width, in_channels, out_channels]这样的shape,具体含义是[卷积核的高度,卷积核的宽度,图像通道数,卷积核个数],要求类型与参数input相同,有一个地方需要注意,第三维in_channels,就是参数input的第四维
  • 第三个参数strides:卷积时在图像每一维的步长,这是一个一维的向量,长度4
  • 第四个参数padding:string类型的量,只能是"SAME","VALID"其中之一,这个值决定了不同的卷积方式(后面会介绍)
  • 第五个参数:use_cudnn_on_gpu:bool类型,是否使用cudnn加速,默认为true

结果返回一个Tensor,这个输出,就是我们常说的feature map

#Input	(placeholder)
x = tf.placeholder(tf.float32,shape=[None,784])
y_ = tf.placeholder(tf.float32,shape=[None,10])
#Convolution layer
x_image = tf.reshape(x, [-1,28,28,1]) W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32]) h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) #维度为[-1,28,28,32]
h_pool1 = max_pool_2x2(h_conv1)#维度为[-1,14,14,32] W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) #维度为[-1,14,14,64]
h_pool2 = max_pool_2x2(h_conv2) #维度为[-1,7,7,64]

tf.reshape(x, [-1,28,28,1])其中的-1表示由后面的几个维度来确定,

例如:

t=[[1, 2], [3, 4], [5, 6], [7, 8]] ,那么t的维度是[4,2]

(1) reshape(t,[2,4])后,t为[[1, 2, 3, 4], [5, 6, 7, 8]]

(2) reshape(t,[-1,4])后,t同样为[[1, 2, 3, 4], [5, 6, 7, 8]],所以这里的-1实际上为2。

#Fully connected layer
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
#dropout
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

训练时使用dropout,减少过拟合

#Readout	Layer
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
#Training and Evaluation
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv)) #Evaluation
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) #optimizer
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) #accuracy
sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
t1=time.time()
for i in range(4000):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
print("step %d, training accuracy %g"%(i, train_accuracy))
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) print("test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels,keep_prob: 1.0}))
t2=time.time()
print(t2-t1)

tensorflow MNIST Convolutional Neural Network的更多相关文章

  1. Tensorflow - Implement for a Convolutional Neural Network on MNIST.

    Coding according to TensorFlow 官方文档中文版 中文注释源于:tf.truncated_normal与tf.random_normal TF-卷积函数 tf.nn.con ...

  2. Convolutional Neural Network in TensorFlow

    翻译自Build a Convolutional Neural Network using Estimators TensorFlow的layer模块提供了一个轻松构建神经网络的高端API,它提供了创 ...

  3. 卷积神经网络(Convolutional Neural Network,CNN)

    全连接神经网络(Fully connected neural network)处理图像最大的问题在于全连接层的参数太多.参数增多除了导致计算速度减慢,还很容易导致过拟合问题.所以需要一个更合理的神经网 ...

  4. 【转载】 卷积神经网络(Convolutional Neural Network,CNN)

    作者:wuliytTaotao 出处:https://www.cnblogs.com/wuliytTaotao/ 本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可,欢迎 ...

  5. ASPLOS'17论文导读——SC-DCNN: Highly-Scalable Deep Convolutional Neural Network using Stochastic Computing

    今年去参加了ASPLOS 2017大会,这个会议总体来说我感觉偏系统和偏软一点,涉及硬件的相对少一些,对我这个喜欢算法以及硬件架构的菜鸟来说并不算非常契合.中间记录了几篇相对比较有趣的paper,今天 ...

  6. 斯坦福大学卷积神经网络教程UFLDL Tutorial - Convolutional Neural Network

    Convolutional Neural Network Overview A Convolutional Neural Network (CNN) is comprised of one or mo ...

  7. Convolutional neural network (CNN) - Pytorch版

    import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms # ...

  8. 1 - ImageNet Classification with Deep Convolutional Neural Network (阅读翻译)

    ImageNet Classification with Deep Convolutional Neural Network 利用深度卷积神经网络进行ImageNet分类 Abstract We tr ...

  9. 论文阅读(Weilin Huang——【TIP2016】Text-Attentional Convolutional Neural Network for Scene Text Detection)

    Weilin Huang--[TIP2015]Text-Attentional Convolutional Neural Network for Scene Text Detection) 目录 作者 ...

随机推荐

  1. MinIO 搭建使用

    MinIO简介¶ MinIO 是一款基于Go语言的高性能对象存储服务,在Github上已有19K+Star.它采用了Apache License v2.0开源协议,非常适合于存储大容量非结构化的数据, ...

  2. 深入浅出 Typescript 学习笔记

    TypeScript 是 JavaScript 的一个超集,支持 ECMAScript 6 标准. TypeScript 由微软开发的自由和开源的编程语言. TypeScript 设计目标是开发大型应 ...

  3. spring boot集成spring-boot-starter-mail邮件功能

    前情提要 以目前IT系统功能来看,邮件功能是非常重要的一个功能.例如:找回密码.邮箱验证,邮件动态码.忘记密码,邮件营销等,都需要用到邮件功能.结合当下最流行的spring boot微服务,推出了sp ...

  4. Java String类相关知识梳理(含字符串常量池(String Pool)知识)

    目录 1. String类是什么 1.1 定义 1.2 类结构 1.3 所在的包 2. String类的底层数据结构 3. 关于 intern() 方法(重点) 3.1 作用 3.2 字符串常量池(S ...

  5. java通过word模板生成word文档

    介绍 上次公司项目需要一个生成word文档的功能,有固定的模板根据业务填充数据即可,由于从来没做过,项目也比较着急于是去网上找有没有合适的工具类,找了好几种,看到其中有freeMark模板生成比较靠谱 ...

  6. 两个大数相乘 - 高精度FFT

    HDU 1402 A * B Problem Plus Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (J ...

  7. 一文带你看清HTTP所有概念

    上一篇文章我们大致讲解了一下 HTTP 的基本特征和使用,大家反响很不错,那么本篇文章我们就来深究一下 HTTP 的特性.我们接着上篇文章没有说完的 HTTP 标头继续来介绍(此篇文章会介绍所有标头的 ...

  8. Java.前端.Layer.open.btn验证无效

    今天遇到了一个很可笑的问题,在.Layer弹窗open中设置了多个按钮,只有yes按钮有效,btn2点击后直接关闭弹窗,排查了2个小时后终于解决,就是btn2要return false! var in ...

  9. java 数组2

    一.创建异常 1.空指针异常 2.超出索引范围 二.遍历 for循环 三.求数组中的最大值 package cn.wt.day05.demon02; public class DemonArray03 ...

  10. Linux下安装JDK 1.8

    前言 JDK是 JAVA 的软件开发工具包,如果要使用JAVA来进行开发,或者部署基于其开发的应用,那么就需要安装JDK.本次将在Linux下安装JDK及配置环境. 本人环境:CentOS 7.3 6 ...