tensorflow MNIST Convolutional Neural Network

MNIST CNN 包含的几个部分:

  • Weight Initialization
  • Convolution and Pooling
  • Convolution layer
  • Fully connected layer
  • Readout Layer

直接上tensorflow 给的示例:

先读入数据:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
import time
#Weight Initialization
#现在还没有值,只是计算图中的节点,直到‘tf.global_variables_initializer()’才初始化
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
#Convolution and Pooling
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)

  • 第一个参数input:指需要做卷积的输入图像,它要求是一个Tensor,具有[batch, in_height, in_width, in_channels]这样的shape,具体含义是[训练时一个batch的图片数量, 图片高度, 图片宽度, 图像通道数],注意这是一个4维的Tensor,要求类型为float32和float64其中之一
  • 第二个参数filter:相当于CNN中的卷积核,它要求是一个Tensor,具有[filter_height, filter_width, in_channels, out_channels]这样的shape,具体含义是[卷积核的高度,卷积核的宽度,图像通道数,卷积核个数],要求类型与参数input相同,有一个地方需要注意,第三维in_channels,就是参数input的第四维
  • 第三个参数strides:卷积时在图像每一维的步长,这是一个一维的向量,长度4
  • 第四个参数padding:string类型的量,只能是"SAME","VALID"其中之一,这个值决定了不同的卷积方式(后面会介绍)
  • 第五个参数:use_cudnn_on_gpu:bool类型,是否使用cudnn加速,默认为true

结果返回一个Tensor,这个输出,就是我们常说的feature map

#Input	(placeholder)
x = tf.placeholder(tf.float32,shape=[None,784])
y_ = tf.placeholder(tf.float32,shape=[None,10])
#Convolution layer
x_image = tf.reshape(x, [-1,28,28,1]) W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32]) h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) #维度为[-1,28,28,32]
h_pool1 = max_pool_2x2(h_conv1)#维度为[-1,14,14,32] W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) #维度为[-1,14,14,64]
h_pool2 = max_pool_2x2(h_conv2) #维度为[-1,7,7,64]

tf.reshape(x, [-1,28,28,1])其中的-1表示由后面的几个维度来确定,

例如:

t=[[1, 2], [3, 4], [5, 6], [7, 8]] ,那么t的维度是[4,2]

(1) reshape(t,[2,4])后,t为[[1, 2, 3, 4], [5, 6, 7, 8]]

(2) reshape(t,[-1,4])后,t同样为[[1, 2, 3, 4], [5, 6, 7, 8]],所以这里的-1实际上为2。

#Fully connected layer
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
#dropout
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

训练时使用dropout,减少过拟合

#Readout	Layer
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
#Training and Evaluation
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv)) #Evaluation
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) #optimizer
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) #accuracy
sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
t1=time.time()
for i in range(4000):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
print("step %d, training accuracy %g"%(i, train_accuracy))
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) print("test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels,keep_prob: 1.0}))
t2=time.time()
print(t2-t1)

tensorflow MNIST Convolutional Neural Network的更多相关文章

  1. Tensorflow - Implement for a Convolutional Neural Network on MNIST.

    Coding according to TensorFlow 官方文档中文版 中文注释源于:tf.truncated_normal与tf.random_normal TF-卷积函数 tf.nn.con ...

  2. Convolutional Neural Network in TensorFlow

    翻译自Build a Convolutional Neural Network using Estimators TensorFlow的layer模块提供了一个轻松构建神经网络的高端API,它提供了创 ...

  3. 卷积神经网络(Convolutional Neural Network,CNN)

    全连接神经网络(Fully connected neural network)处理图像最大的问题在于全连接层的参数太多.参数增多除了导致计算速度减慢,还很容易导致过拟合问题.所以需要一个更合理的神经网 ...

  4. 【转载】 卷积神经网络(Convolutional Neural Network,CNN)

    作者:wuliytTaotao 出处:https://www.cnblogs.com/wuliytTaotao/ 本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可,欢迎 ...

  5. ASPLOS'17论文导读——SC-DCNN: Highly-Scalable Deep Convolutional Neural Network using Stochastic Computing

    今年去参加了ASPLOS 2017大会,这个会议总体来说我感觉偏系统和偏软一点,涉及硬件的相对少一些,对我这个喜欢算法以及硬件架构的菜鸟来说并不算非常契合.中间记录了几篇相对比较有趣的paper,今天 ...

  6. 斯坦福大学卷积神经网络教程UFLDL Tutorial - Convolutional Neural Network

    Convolutional Neural Network Overview A Convolutional Neural Network (CNN) is comprised of one or mo ...

  7. Convolutional neural network (CNN) - Pytorch版

    import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms # ...

  8. 1 - ImageNet Classification with Deep Convolutional Neural Network (阅读翻译)

    ImageNet Classification with Deep Convolutional Neural Network 利用深度卷积神经网络进行ImageNet分类 Abstract We tr ...

  9. 论文阅读(Weilin Huang——【TIP2016】Text-Attentional Convolutional Neural Network for Scene Text Detection)

    Weilin Huang--[TIP2015]Text-Attentional Convolutional Neural Network for Scene Text Detection) 目录 作者 ...

随机推荐

  1. getopt命令

    最近学习了一下getopt(不是getopts)命令来处理执行shell脚本传入的参数,在此记录一下,包括长选项.短选项.以及选项的值出现的空格问题,最后写了个小的脚本来处理输入的参数 首先新建一个t ...

  2. 我与Git的那些破事--代码管理实践

    1. Git是什么? 作为一名程序猿,我相信大家都或多或少接触过git--分布式版本控制软件. 有人说,它是目前世界上最先进的分布式版本控制系统,我想说,是否最先进不知道,但确实好用,实用. 作为一款 ...

  3. SqlBulkCopy批量插入和索引的关系

    .net中批量插入基本都用SqlBulkCopy,速度很快,但是这几天发现个问题,2000数据居然15s,百思不得其解.经过大量测试,发现过多的索引和索引碎片会严重影响插入速度,表的数据量大小反而不会 ...

  4. 源码详解系列(六) ------ 全面讲解druid的使用和源码

    简介 druid是用于创建和管理连接,利用"池"的方式复用连接减少资源开销,和其他数据源一样,也具有连接数控制.连接可靠性测试.连接泄露控制.缓存语句等功能,另外,druid还扩展 ...

  5. Python中url标签使用详解

    url标签: 1.在模板中,我们经常要使用一些url,实现页面之间的跳转,比如某个a标签中需要定义href属性.当然如果通过硬编码的方式直接将这个url固定在里面也是可以的,但是这样的话,对于以后进行 ...

  6. xtrabackup备份还原mariadb数据库

    一.xtrabackup 简介 xtrabackup 是由percona公司开源免费的数据库热备软件,它能对InnoDB数据库和XtraDB存储引擎的数据库非阻塞地备份,对于myisam的备份同样需要 ...

  7. 一文带你看清HTTP所有概念

    上一篇文章我们大致讲解了一下 HTTP 的基本特征和使用,大家反响很不错,那么本篇文章我们就来深究一下 HTTP 的特性.我们接着上篇文章没有说完的 HTTP 标头继续来介绍(此篇文章会介绍所有标头的 ...

  8. 研究僧丨Window实用利器分享

    本人CS在读小硕,平时工作环境主要是win10加ubuntu,下面推荐一些我用过且觉得不错的应用. PS:我列举的应用基本被下面的网站收录,大家不妨去里面淘淘看. Windows 绝妙项目 Aweso ...

  9. 个人任务day6

    今日计划: 学会将网页放到公用网络上,并生成快捷方式. 昨日成果:完成登录页面.

  10. 「 从0到1学习微服务SpringCloud 」02 Eureka服务注册与发现

    系列文章(更新ing): 「 从0到1学习微服务SpringCloud 」01 一起来学呀! Spring Cloud Eureka 基于Netflix Eureka做了二次封装(Spring Clo ...