import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data #设置输入参数
batch_size = 128
test_size = 256 # 初始化权值与定义网络结构,建构一个3个卷积层和3个池化层,一个全连接层和一个输出层的卷积神经网络
# 首先定义初始化权重函数
def init_weights(shape):
return tf.Variable(tf.random_normal(shape, stddev=0.01)) # 第一组卷积层以及池化层,最后 droupout是为了防止过拟合,在模型训练的时候丢掉一些神经元
# padding表示对边界的处理,SAME表示卷积的输入和输出保持同样尺寸
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
l1a = tf.nn.relu(tf.nn.conv2d(X, w,strides=[1, 1, 1, 1], padding='SAME'))
# l1 shape=(?, 14, 14, 32)
l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')
l1 = tf.nn.dropout(l1, p_keep_conv)
# 第二组卷积层及池化层,最后dropout一些神经元
# l2a shape=(?, 14, 14, 64)
l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
# l2 shape=(?, 7, 7, 64)
l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')
l2 = tf.nn.dropout(l2, p_keep_conv) # 第三组卷积神经网络及池化层,同样,最后dropout一些神经元
# l3a shape=(?, 7, 7, 128)
l3a = tf.nn.relu(tf.nn.conv2d(l2, w3,strides=[1, 1, 1, 1], padding='SAME'))
# l3 shape=(?, 4, 4, 128)
l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')
# reshape to (?, 2048)
l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]])
l3 = tf.nn.dropout(l3, p_keep_conv)
# 全连接层
l4 = tf.nn.relu(tf.matmul(l3, w4))
l4 = tf.nn.dropout(l4, p_keep_hidden)
# 输出层
pyx = tf.matmul(l4, w_o)
return pyx # 导入数据
mnist = input_data.read_data_sets("E:\\MNIST_data\\", one_hot=True)
# 定义四个变量,分别为输入训练图像矩阵及其标签,输入测试图像矩阵及其标签
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
# -1表示布考虑输入图片的数量,28*28为图片的像素数,1是通道(channel)的数量,
# 因MNIST图片为黑白,彩色图片通道是3
# 28x28x1
trX = trX.reshape(-1, 28, 28, 1)
# 28x28x1
teX = teX.reshape(-1, 28, 28, 1) X = tf.placeholder("float", [None, 28, 28, 1])
# 10为识别图片的类别从0到9,共10个取值
Y = tf.placeholder("float", [None, 10]) # 定义模型函数
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w: 每一层权重
# 大小为3*3,输入的维度为1 ,输出维度为32
w = init_weights([3, 3, 1, 32])
# 大小为3*3,输入维度为32,输出维度为64
w2 = init_weights([3, 3, 32, 64])
# 大小为3*3,输入维度为64,输出维度为128
w3 = init_weights([3, 3, 64, 128])
# 全连接层,输入维度为128*4*4,也就是上一层的输出,输出维度为625
w4 = init_weights([128 * 4 * 4, 625])
# 输出层,输入的维度为625, 输出110维,代表10类(labels)
w_o = init_weights([625, 10]) # p_keep_conv,p_keep_hidden:dropout 保留神经元比例
# 定义dropout的占位符keep_conv,表示一层中有多少比例的神经元被保留,生成网络模型,得到预测数据
# 在训练的时候把设定比例的节点改为0,避免过拟合
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) # 定义损失函数,采用tf.nn.softmax_cross_entropy_with_logists,作为比较预测值和真实值的差距
# 定义训练操作(train_op) 采用RMSProp算法作为优化器,
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1) #在会话中定义图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
# you need to initialize all variabels
tf.global_variables_initializer().run()
for i in range(100):
training_batch=zip(range(0,len(trX),batch_size),range(batch_size,len(trX)+1,batch_size))
for start, end in training_batch:
sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],p_keep_conv: 0.8, p_keep_hidden: 0.5})
test_indices = np.arange(len(teX)) # Get A Test Batch
np.random.shuffle(test_indices)
test_indices = test_indices[0:test_size]
#预测的时候设置为1 即对全部样本进行迭代训练
print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==sess.run(predict_op, feed_dict={X: teX[test_indices],p_keep_conv: 1.0,p_keep_hidden: 1.0})))

吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow实现回归模型训练预测MNIST手写数据集

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  2. matlab练习程序(神经网络识别mnist手写数据集)

    记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...

  3. TensorFlow实战第五课(MNIST手写数据集识别)

    Tensorflow实现softmax regression识别手写数字 MNIST手写数字识别可以形象的描述为机器学习领域中的hello world. MNIST是一个非常简单的机器视觉数据集.它由 ...

  4. 用Kersa搭建神经网络【MNIST手写数据集】

    MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...

  5. TensorFlow系列专题(六):实战项目Mnist手写数据集识别

    欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...

  6. Python之TensorFlow的卷积神经网络-5

    一.卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深度 ...

  7. 吴裕雄--天生自然TensorFlow高层封装:使用TFLearn处理MNIST数据集实现LeNet-5模型

    # 1. 通过TFLearn的API定义卷机神经网络. import tflearn import tflearn.datasets.mnist as mnist from tflearn.layer ...

  8. TensorFlow——MNIST手写数据集

    MNIST数据集介绍 MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集. ...

  9. 吴裕雄 python 神经网络——TensorFlow 实现LeNet-5模型处理MNIST手写数据集

    import os import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import ...

随机推荐

  1. 题解 洛谷 P4145 【上帝造题的七分钟2 / 花神游历各国】

    题目 上帝造题的七分钟2 / 花神游历各国 题目背景 XLk觉得<上帝造题的七分钟>不太过瘾,于是有了第二部. 题目描述 "第一分钟,X说,要有数列,于是便给定了一个正整数数列. ...

  2. 素问 - 使用 PE、PB 做估值

    摘自<小韭的学习圈> Q 哪些行业用PE看合适,哪些用PB看合适啊?其中的大致逻辑是什么? A PE = 股价 / 每股收益 使用PE的逻辑是,我们认为一个股票有价值,是因为公司未来能赚钱 ...

  3. 编码 - 坑 - win10 下采用 utf-8, 导致 gitbash 中文字体异常, 待解决

    blog01 概述 使用 git 中, 遇到一个坑 背景 最近遇到一个 编码转换 问题 本来也 一知半解 要是有人能给我讲讲就好了 环境 win10 1903 git 2.20.1 1. 问题 概述 ...

  4. StreamPipes

    MQTT is a machine-to-machine (M2M)/"Internet of Things" connectivity protocol. It was desi ...

  5. SigXplorer设置延时及Local_Global

    通过SigXplorer设置绝对延时和相对延时及对Local-Global的理解 一.基本理解 (感觉可能有偏差) 在于博士的教程第44和45讲中,分别对绝对延时和相对延时进行了设置,通过SigXpl ...

  6. sftp,ftp文件下载

    一.sftp工具类 package com.ztesoft.iotcmp.util; import com.jcraft.jsch.ChannelSftp; import com.jcraft.jsc ...

  7. Sobel边缘检测算法

    索贝尔算子(Sobel operator)主要用作边缘检测,在技术上,它是一离散性差分算子,用来运算图像亮度函数的灰度之近似值.在图像的任何一点使用此算子,将会产生对应的灰度矢量或是其法矢量 Sobe ...

  8. Java:面向对象的编程语言

    java是面向对象的编程语言 Object,就是指面向对象的对象,对象就是实例. 在java里,对象是类的一个具体实例.就像:人,指一个类.你.我.他.张三.李四.王五等则是一个个具体的实例,也就是j ...

  9. SFSA

    #include<stdio.h> #include<string.h> #include<math.h> #include<iostream> #in ...

  10. 【转载】在windows下使用gcc编译jni的简单教程

    转自:http://veikr.com/201207/windows_gcc_jni.html 1.安装MinGW,这个可以为windows提供gcc编译环境. 到http://sourceforge ...