使用卷积神经网络CNN训练识别mnist
算的的上是自己搭建的第一个卷积神经网络。网络结构比较简单。
输入为单通道的mnist数据集。它是一张28*28,包含784个特征值的图片
我们第一层输入,使用5*5的卷积核进行卷积,输出32张特征图,然后使用2*2的池化核进行池化 输出14*14的图片
第二层 使用5*5的卷积和进行卷积,输出64张特征图,然后使用2*2的池化核进行池化 输出7*7的图片
第三层为全连接层 我们总结有 7*7*64 个输入,输出1024个节点 ,使用relu作为激活函数,增加一个keep_prob的dropout层
第四层为输出层,我们接收1024个输入,输出长度为10的one-hot向量。使用softmax作为激活函数
使用交叉熵作为损失函数
网络模型代码:
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import argparse
- import sys
- import tempfile
- from tensorflow.examples.tutorials.mnist import input_data
- import tensorflow as tf
- FLAGS = None
- def weight_variable(shape):
- init=tf.truncated_normal(shape=shape,stddev=0.1,mean=1.)
- return tf.Variable(init)
- def bias_variable(shape):
- init=tf.constant(0.1,shape=shape)
- return tf.Variable(init)
- def conv2d(x,w):
- return tf.nn.conv2d(x,w,[1,1,1,1],padding="SAME")
- def max_pool_2x2(x):
- return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
- def deepnn(x):
- with tf.name_scope('reshape'):
- x_image=tf.reshape(x,[-1,28,28,1])
- #第一层卷积和池化
- with tf.name_scope('conv1'):
- #输入为1张图片 卷积核为5*5 生成32个特征图
- w_conv1=weight_variable([5,5,1,32])
- b_conv1=bias_variable([32])
- h_conv1=tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
- with tf.name_scope('pool1'):
- h_pool1=max_pool_2x2(h_conv1)
- #第二层卷积和池化
- with tf.name_scope("conv2"):
- #输入为32张特征图,卷积核为5*5 输出64张特征图
- w_conv2=weight_variable([5,5,32,64])
- b_conv2=bias_variable([64])
- h_conv2=tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
- with tf.name_scope("pool2"):
- h_pool2=max_pool_2x2(h_conv2)
- #第一层全连接层,将特征图展开为特征向量,与1024个节点连接
- with tf.name_scope("fc1"):
- w_fc1=weight_variable([7*7*64,1024])
- b_fc1=bias_variable([1024])
- h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
- h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
- #dropout层,训练时随机让某些隐含层节点权重不工作
- with tf.name_scope("dropout1"):
- keep_prob=tf.placeholder(tf.float32)
- h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
- #第二个全连接层,连接1024个节点,输出one-hot预测
- with tf.name_scope("fc2"):
- w_fc2=weight_variable([1024,10])
- b_fc2=bias_variable([10])
- h_fc2=tf.matmul(h_fc1_drop,w_fc2)+b_fc2
- return h_fc2,keep_prob
训练代码:
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import argparse
- import sys
- import tempfile
- from tensorflow.examples.tutorials.mnist import input_data
- import tensorflow as tf
- import mnist_model
- FLAGS = None
- def main(_):
- mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
- #设置输入变量
- x=tf.placeholder(dtype=tf.float32,shape=[None,784])
- #设置输出变量
- y_real=tf.placeholder(dtype=tf.float32,shape=[None,10])
- #实例化网络
- y_pre,keep_prob=mnist_model.deepnn(x)
- #设置损失函数
- with tf.name_scope("loss"):
- cross_entropy=tf.nn.softmax_cross_entropy_with_logits(logits=y_pre,labels=y_real)
- loss=tf.reduce_mean(cross_entropy)
- #设置优化器
- with tf.name_scope("adam_optimizer"):
- train_step=tf.train.AdamOptimizer(1e-4).minimize(loss)
- #计算正确率:
- with tf.name_scope("accuracy"):
- correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y_real, 1))
- correct_prediction = tf.cast(correct_prediction, tf.float32)
- accuracy = tf.reduce_mean(correct_prediction)
- #将神经网络图模型保存
- graph_location=tempfile.mkdtemp()
- print('saving graph to %s'%graph_location)
- train_writer=tf.summary.FileWriter(graph_location)
- train_writer.add_graph(tf.get_default_graph())
- #将训练的网络保存下来
- saver=tf.train.Saver()
- with tf.Session() as sess:
- sess.run(tf.global_variables_initializer())
- for i in range(5000):
- batch=mnist.train.next_batch(50)
- if i%100==0:
- train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_real: batch[1], keep_prob: 1.0})
- print('step %d, training accuracy %g' % (i, train_accuracy))
- sess.run(train_step,feed_dict={x: batch[0], y_real: batch[1], keep_prob: 0.5})
- #在测试集上进行测试
- test_accuracy = 0
- for i in range(200):
- batch = mnist.test.next_batch(50)
- test_accuracy += accuracy.eval(feed_dict={x: batch[0], y_real: batch[1], keep_prob: 1.0}) / 200;
- print('test accuracy %g' % test_accuracy)
- save_path = saver.save(sess, "mnist_cnn_model.ckpt")
- if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--data_dir', type=str,
- default='./',
- help='Directory for storing input data')
- FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
部分训练结果:
- step 3600, training accuracy 0.98
- step 3700, training accuracy 0.98
- step 3800, training accuracy 0.96
- step 3900, training accuracy 1
- step 4000, training accuracy 0.98
- step 4100, training accuracy 0.96
- step 4200, training accuracy 1
- step 4300, training accuracy 1
- step 4400, training accuracy 0.98
- step 4500, training accuracy 0.98
- step 4600, training accuracy 0.98
- step 4700, training accuracy 1
- step 4800, training accuracy 0.98
- step 4900, training accuracy 1
- test accuracy 0.9862
使用卷积神经网络CNN训练识别mnist的更多相关文章
- 卷积神经网络(CNN)代码实现(MNIST)解析
在http://blog.csdn.net/fengbingchun/article/details/50814710中给出了CNN的简单实现,这里对每一步的实现作个说明: 共7层:依次为输入层.C1 ...
- 基于MNIST数据的卷积神经网络CNN
基于tensorflow使用CNN识别MNIST 参数数量:第一个卷积层5x5x1x32=800个参数,第二个卷积层5x5x32x64=51200个参数,第三个全连接层7x7x64x1024=3211 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 写给程序员的机器学习入门 (八) - 卷积神经网络 (CNN) - 图片分类和验证码识别
这一篇将会介绍卷积神经网络 (CNN),CNN 模型非常适合用来进行图片相关的学习,例如图片分类和验证码识别,也可以配合其他模型实现 OCR. 使用 Python 处理图片 在具体介绍 CNN 之前, ...
- 深度学习之卷积神经网络(CNN)详解与代码实现(二)
用Tensorflow实现卷积神经网络(CNN) 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10737065. ...
- 基于卷积神经网络的人脸识别项目_使用Tensorflow-gpu+dilib+sklearn
https://www.cnblogs.com/31415926535x/p/11001669.html 基于卷积神经网络的人脸识别项目_使用Tensorflow-gpu+dilib+sklearn ...
- python机器学习卷积神经网络(CNN)
卷积神经网络(CNN) 关注公众号"轻松学编程"了解更多. 一.简介 卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,它的人 ...
- 卷积神经网络(CNN,ConvNet)
卷积神经网络(CNN,ConvNet) 卷积神经网络(CNN,有时被称为 ConvNet)是很吸引人的.在短时间内,变成了一种颠覆性的技术,打破了从文本.视频到语音等多个领域所有最先进的算法,远远超出 ...
- TensorFlow 2.0 深度学习实战 —— 浅谈卷积神经网络 CNN
前言 上一章为大家介绍过深度学习的基础和多层感知机 MLP 的应用,本章开始将深入讲解卷积神经网络的实用场景.卷积神经网络 CNN(Convolutional Neural Networks,Conv ...
随机推荐
- IO 多路复用是什么意思?
在同一个线程里面, 通过拨开关的方式,来同时传输多个I/O流, (学过EE的人现在可以站出来义正严辞说这个叫“时分复用”了). 什么,你还没有搞懂“一个请求到来了,nginx使用epoll接收请求的过 ...
- hibernate使用注解设置日期默认值
用注解设置属性的默认值时 使用 @Temporal(TemporalType.TIMESTAMP) @Column(updatable = false,nullable=false,length=20 ...
- 日志收集-Flume-ng-mongodb-sink
本文主要介绍使用Flume传输数据到MongoDB的过程,内容涉及环境部署和注意事项. 一.环境搭建 1.flune-ng下载地址:http://www.apache.org/dyn/closer.c ...
- 31天重构学习笔记(java版本)
准备下周分享会的内容,无意间看到.net版本的重构31天,花了两个小时看了下,可以看成是Martin Fowler<重构>的精简版 原文地址:http://www.lostechies.c ...
- 【转】编辑器与IDE
编辑器与IDE 无谓的编辑器战争 很多人都喜欢争论哪个编辑器是最好的.其中最大的争论莫过于 Emacs 与 vi 之争.vi 的支持者喜欢说:“看 vi 打起字来多快,手指完全不离键盘,连方向键都可以 ...
- MySQL 5.7.19 CentOS 7 安装
Linux的版本有很多,因此下载mysql时,需要注意下载对应Linux版本的MySql数据库文件.以下方法也适合centOS 7 的mysql 5.7.* 版本的安装.安装方法我整理为16步. 1: ...
- ThinkPHP 3.2 性能优化,实现高性能API开发
需求分析 目前的业务全站使用ThinkPHP 3.2.3,前台.后台.Cli.Api等.目前的业务API访问量数千万,后端7台PHP 5.6,平均CPU使用率20%. 测试数据 真实业务 php5.6 ...
- asp.net中WinForm使用单例模式示例
例如在Windows应用程序中用下面代码打开一个窗体: 代码如下 复制代码 private void button1_Click(object sender, EventArgs e){ (new A ...
- SQL触发器 常用语句
一.创建一个简单的触发器 CREATE TRIGGER 触发器名称 ON 表名 FOR INSERT.UPDATE 或 DELETE AS T-SQL 语句 注意:触发器名称是不加引号的. ...
- labview中小黑点,小红点
小黑点:在labview中每一个小黑点就代表了一次内存的分配,通过小黑点可以帮助我们分析数据变量的内存拷贝情况