基于卷积神经网络的手写数字识别分类(Tensorflow)
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
%matplotlib inline
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
class ConvModel(object):
def __init__(self, lr, batch_size, iter_num):
self.lr = lr
self.batch_size = batch_size
self.iter_num = iter_num
self.X_flat = tf.placeholder(tf.float32, [None, 784])
self.X = tf.reshape(self.X_flat, [-1, 28, 28, 1]) # 本次要用卷积进行运算,所以使用2维矩阵。从这个角度讲,利用了更多的位置信息。
self.y = tf.placeholder(tf.float32, [None, 10])
self.dropRate = tf.placeholder(tf.float32)
conv1 = tf.layers.conv2d(self.X, 32, 5, padding='same', activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.1, seed=0),
bias_initializer=tf.constant_initializer(0.1))
conv1 = tf.layers.max_pooling2d(conv1 , 2,2)
conv2 = tf.layers.conv2d(conv1, 64, 5, padding='same', activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.1, seed=0),
bias_initializer=tf.constant_initializer(0.1))
pool1 = tf.layers.max_pooling2d(conv2, 2,2)
flatten = tf.reshape(pool1 , [-1, 7*7*64])
dense1 = tf.layers.dense(flatten, 1024, activation=tf.nn.relu, use_bias=True,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.1, seed=0),
bias_initializer=tf.constant_initializer(0.1))
dense1_ = tf.nn.dropout(dense1, self.dropRate)
dense2 = tf.layers.dense(dense1_, 10, activation=tf.nn.relu, use_bias=True,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.1, seed=0),
bias_initializer=tf.constant_initializer(0.1))
self.loss = tf.losses.softmax_cross_entropy(onehot_labels=self.y, logits=dense2)
self.train_step = tf.train.AdamOptimizer(1e-4).minimize(self.loss )
# 用于模型训练
self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(dense2, axis=1))
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
# 用于保存训练好的模型
self.saver = tf.train.Saver()
def train(self):
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # 先初始化所有变量。
for i in range(self.iter_num):
batch_x, batch_y = mnist.train.next_batch(self.batch_size) # 读取一批数据
loss, _= sess.run([self.loss, self.train_step],
feed_dict={self.X_flat: batch_x, self.y: batch_y, self.dropRate: 0.5}) # 每调用一次sess.run,就像拧开水管一样,所有self.loss和self.train_step涉及到的运算都会被调用一次。
if i%1000 == 0:
train_accuracy = sess.run(self.accuracy, feed_dict={self.X_flat: batch_x, self.y: batch_y, self.dropRate: 1.}) # 把训练集数据装填进去
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.X_flat: test_x, self.y: test_y, self.dropRate: 1.}) # 把测试集数据装填进去
print ('iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f' % (i,loss,train_accuracy,test_accuracy))
self.saver.save(sess, 'model/mnistModel') # 保存模型
def test(self):
with tf.Session() as sess:
self.saver.restore(sess, 'model/mnistModel')
Accuracy = []
for i in range(int(10000/self.batch_size)):
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.X_flat: test_x, self.y: test_y, self.dropRate: 1.})
Accuracy.append(test_accuracy)
print('==' * 15)
print( 'Test Accuracy: ', np.mean(np.array(Accuracy)) )
model = ConvModel(0.001, 64, 30000) # 学习率为0.001,每批传入64张图,训练30000次
model.train() # 训练模型
model.test() # 预测
基于卷积神经网络的手写数字识别分类(Tensorflow)的更多相关文章
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- 卷积神经网络CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- TensorFlow(十):卷积神经网络实现手写数字识别以及可视化
上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)
莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...
- BP神经网络的手写数字识别
BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...
- 利用c++编写bp神经网络实现手写数字识别详解
利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
- TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)
从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作.这次我们要解决机器学习的经典问题,MNIST手写数字识别. 首先介绍一下数据集.请首先解压:TF_Net\Asset\mnist_png. ...
随机推荐
- AngularJS 服务 provider factory service及区别
一.概念说明 1.服务是对公共代码的抽象,如多个控制器都出现了相似代码,把他们抽取出来,封装成一个服务,遵循DRY原则,增强可维护性,剥离了和具体表现相关的部分,聚焦于业务逻辑或交互逻辑,更加容易被测 ...
- 在Windows7系统上能正常使用的程序,Windows10运行后部分状态不能及时变更
这是最近在开发一个通信项目时遇到的问题,一开始以为是窗体样式的原因,把窗体换成系统窗体之后还是在Win10上不能正常使用,后面突然想到会不会是匹配原因,试了一下,结果真的就正常了. 问题:例如一个通信 ...
- AEAI Portal 权限体系说明
1.概述 在数通畅联的产品体系中,AEAI Portal毫无疑问的占据了很重要的地位,在这里我们将通过参考Portal样例,讲述一下AEAI Portal权限体系的控制方法.在Portal使用过程中, ...
- 【mock】后端不来过夜半,闲敲mock落灯花 (mockjs+Vuex+Vue实战)
mock的由来[假] 赵师秀:南宋时期的一位前端工程师 诗词背景:在一个梅雨纷纷的夜晚,正处于项目编码阶段,书童却带来消息:写后端的李秀才在几个时辰前就赶往临安度假去了,!此时手头仅有一个简单 ...
- 多项式&生成函数(~~乱讲~~)
多项式 多项式乘法 FFT,NTT,MTT不是前置知识吗?随便学一下就好了(虽然我到现在还是不会MTT,exlucas也不会用) FTT总结 NTT总结 泰勒展开 如果一个多项式\(f(x)\)在\( ...
- MariaDB 表的基本操作(3)
MariaDB数据库管理系统是MySQL的一个分支,主要由开源社区在维护,采用GPL授权许可MariaDB的目的是完全兼容MySQL,包括API和命令行,MySQL由于现在闭源了,而能轻松成为MySQ ...
- [HTML] 动态修改input placeholder的颜色
.invalid:-moz-placeholder { /* Mozilla Firefox 4 to 18 */ color: red; } .invalid::-moz-placeholder { ...
- 【UML】:类图
1 实线/虚线 + 三角空心箭头: 继承extends:实线,三角空心箭头指向父类,子类指向父类,子类 is a 父类. 实现implements:虚线,三角空心箭头指向接口,类指向接口,类 实现 ...
- vue教程2-06 过滤器
vue教程2-06 过滤器 过滤器: vue提供过滤器: capitalize uppercase currency.... <div id="box"> {{msg| ...
- odoo开发笔记 -- 前台不同视图访问同一个模型
看一下partner这个表, 客户和供应商,都用这个表,那怎么区分呢: 供应商: 客户 注意这两个里面用domain来进行区分: <field name="domain" ...