算的的上是自己搭建的第一个卷积神经网络。网络结构比较简单。

输入为单通道的mnist数据集。它是一张28*28,包含784个特征值的图片

我们第一层输入,使用5*5的卷积核进行卷积,输出32张特征图,然后使用2*2的池化核进行池化 输出14*14的图片

第二层 使用5*5的卷积和进行卷积,输出64张特征图,然后使用2*2的池化核进行池化 输出7*7的图片

第三层为全连接层 我们总结有 7*7*64 个输入,输出1024个节点 ,使用relu作为激活函数,增加一个keep_prob的dropout层

第四层为输出层,我们接收1024个输入,输出长度为10的one-hot向量。使用softmax作为激活函数

使用交叉熵作为损失函数

网络模型代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tempfile
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
FLAGS = None
def weight_variable(shape):
init=tf.truncated_normal(shape=shape,stddev=0.1,mean=1.)
return tf.Variable(init)
def bias_variable(shape):
init=tf.constant(0.1,shape=shape)
return tf.Variable(init)
def conv2d(x,w):
return tf.nn.conv2d(x,w,[1,1,1,1],padding="SAME")
def max_pool_2x2(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME") def deepnn(x):
with tf.name_scope('reshape'):
x_image=tf.reshape(x,[-1,28,28,1])
#第一层卷积和池化
with tf.name_scope('conv1'):
#输入为1张图片 卷积核为5*5 生成32个特征图
w_conv1=weight_variable([5,5,1,32])
b_conv1=bias_variable([32])
h_conv1=tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
with tf.name_scope('pool1'):
h_pool1=max_pool_2x2(h_conv1)
#第二层卷积和池化
with tf.name_scope("conv2"):
#输入为32张特征图,卷积核为5*5 输出64张特征图
w_conv2=weight_variable([5,5,32,64])
b_conv2=bias_variable([64])
h_conv2=tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
with tf.name_scope("pool2"):
h_pool2=max_pool_2x2(h_conv2)
#第一层全连接层,将特征图展开为特征向量,与1024个节点连接
with tf.name_scope("fc1"):
w_fc1=weight_variable([7*7*64,1024])
b_fc1=bias_variable([1024])
h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
#dropout层,训练时随机让某些隐含层节点权重不工作
with tf.name_scope("dropout1"):
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
#第二个全连接层,连接1024个节点,输出one-hot预测
with tf.name_scope("fc2"):
w_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
h_fc2=tf.matmul(h_fc1_drop,w_fc2)+b_fc2
return h_fc2,keep_prob

训练代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import argparse
import sys
import tempfile from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf import mnist_model FLAGS = None
def main(_):
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
#设置输入变量
x=tf.placeholder(dtype=tf.float32,shape=[None,784])
#设置输出变量
y_real=tf.placeholder(dtype=tf.float32,shape=[None,10])
#实例化网络
y_pre,keep_prob=mnist_model.deepnn(x)
#设置损失函数
with tf.name_scope("loss"):
cross_entropy=tf.nn.softmax_cross_entropy_with_logits(logits=y_pre,labels=y_real)
loss=tf.reduce_mean(cross_entropy)
#设置优化器
with tf.name_scope("adam_optimizer"):
train_step=tf.train.AdamOptimizer(1e-4).minimize(loss)
#计算正确率:
with tf.name_scope("accuracy"):
correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y_real, 1))
correct_prediction = tf.cast(correct_prediction, tf.float32)
accuracy = tf.reduce_mean(correct_prediction)
#将神经网络图模型保存
graph_location=tempfile.mkdtemp()
print('saving graph to %s'%graph_location)
train_writer=tf.summary.FileWriter(graph_location)
train_writer.add_graph(tf.get_default_graph())
#将训练的网络保存下来
saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(5000):
batch=mnist.train.next_batch(50)
if i%100==0:
train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_real: batch[1], keep_prob: 1.0})
print('step %d, training accuracy %g' % (i, train_accuracy))
sess.run(train_step,feed_dict={x: batch[0], y_real: batch[1], keep_prob: 0.5})
#在测试集上进行测试
test_accuracy = 0
for i in range(200):
batch = mnist.test.next_batch(50)
test_accuracy += accuracy.eval(feed_dict={x: batch[0], y_real: batch[1], keep_prob: 1.0}) / 200; print('test accuracy %g' % test_accuracy)
save_path = saver.save(sess, "mnist_cnn_model.ckpt")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str,
default='./',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

部分训练结果:

step 3600, training accuracy 0.98
step 3700, training accuracy 0.98
step 3800, training accuracy 0.96
step 3900, training accuracy 1
step 4000, training accuracy 0.98
step 4100, training accuracy 0.96
step 4200, training accuracy 1
step 4300, training accuracy 1
step 4400, training accuracy 0.98
step 4500, training accuracy 0.98
step 4600, training accuracy 0.98
step 4700, training accuracy 1
step 4800, training accuracy 0.98
step 4900, training accuracy 1
test accuracy 0.9862

使用卷积神经网络CNN训练识别mnist的更多相关文章

  1. 卷积神经网络(CNN)代码实现(MNIST)解析

    在http://blog.csdn.net/fengbingchun/article/details/50814710中给出了CNN的简单实现,这里对每一步的实现作个说明: 共7层:依次为输入层.C1 ...

  2. 基于MNIST数据的卷积神经网络CNN

    基于tensorflow使用CNN识别MNIST 参数数量:第一个卷积层5x5x1x32=800个参数,第二个卷积层5x5x32x64=51200个参数,第三个全连接层7x7x64x1024=3211 ...

  3. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  4. 写给程序员的机器学习入门 (八) - 卷积神经网络 (CNN) - 图片分类和验证码识别

    这一篇将会介绍卷积神经网络 (CNN),CNN 模型非常适合用来进行图片相关的学习,例如图片分类和验证码识别,也可以配合其他模型实现 OCR. 使用 Python 处理图片 在具体介绍 CNN 之前, ...

  5. 深度学习之卷积神经网络(CNN)详解与代码实现(二)

    用Tensorflow实现卷积神经网络(CNN) 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10737065. ...

  6. 基于卷积神经网络的人脸识别项目_使用Tensorflow-gpu+dilib+sklearn

    https://www.cnblogs.com/31415926535x/p/11001669.html 基于卷积神经网络的人脸识别项目_使用Tensorflow-gpu+dilib+sklearn ...

  7. python机器学习卷积神经网络(CNN)

    卷积神经网络(CNN) 关注公众号"轻松学编程"了解更多. 一.简介 ​ 卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,它的人 ...

  8. 卷积神经网络(CNN,ConvNet)

    卷积神经网络(CNN,ConvNet) 卷积神经网络(CNN,有时被称为 ConvNet)是很吸引人的.在短时间内,变成了一种颠覆性的技术,打破了从文本.视频到语音等多个领域所有最先进的算法,远远超出 ...

  9. TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN

    前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ...

随机推荐

  1. C语言学习笔记 (003) - C/C++中的实参和形参(转)

    今天突然看到一道关于形参和实参的题,我居然不求甚解.藐视过去在我的脑海里只有一个参数的概念,对于形参和实参的区别还真的不知道,作为学习了几年C++的人来说,真的深深感觉对不起自己对不起C++老师  T ...

  2. C 简单1

    #include <stdio.h> #define Height 10 int main(){ int width; int clong; int result; printf(&quo ...

  3. EF的表左连接方法Include和Join

    在EF中表连接常用的有Join()和Include(),两者都可以实现两张表的连接,但又有所不同. 例如有个唱片表Album(AlbumId,Name,CreateDate,GenreId),表中含外 ...

  4. 【SqlServer】SqlServer索引的创建、查看、删除

    索引加快检索表中数据的方法,它对数据表中一个或者多个列的值进行结构排序,是数据库中一个非常有用的对象. 索引的创建 #1使用企业管理器创建 启动企业管理器--选择数据库------选在要创建索引的表- ...

  5. 【Spring】Spring+SpringMVC+MyBatis框架的搭建

    1,SSM的简介 SSM(Spring+SpringMVC+MyBatis)框架集由Spring.SpringMVC.MyBatis三个开源框架整合而成,常作为数据源较简单的web项目的框架. 其中s ...

  6. 雷军:重刷ROM的“自我格式化”

    本文来源于:百度百家 作者:金错刀 2014-03-14 10:33:06 最近,跟一个前金山高管聊起雷军,特别是雷军的变化,她的感觉是:雷总岂止是变化,简直是格式化,甚至是把自己重刷了一遍ROM. ...

  7. Windows中"打开方式..."无法指定程序的解决办法

    Windows真DT, 今天升级了vim, 从vim73到vim74, 突然发现右键菜单打开方式中的VIM不见了, 于是手动重新指定到vim74\gvim.exe, 未果, Windows就直接忽略了 ...

  8. 用C写有面向对象特点的程序

    比如在一个项目中,有大量的数据结构,他们都是双向链表,但又想共用一套对链表的操作算法,这怎么做到呢,C中又没有C++中的继承,不然我可以继承一父(类中只有两个指针,一个向前一个向后),而其算法可以写在 ...

  9. Oracle数据库中number类型在java中的使用

    1)如果不指定number的长度,或指定长度n>18 id number not null,转换为pojo类时,为java.math.BigDecimal类型 2)如果number的长度在10 ...

  10. es5 温故而知新 简单继承示例

    // 矩形(构造器/父类) function Rectangle (height, width) { this.height = height; this.width = width; } // 获取 ...