import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os mnist = input_data.read_data_sets('MNIST_data', one_hot=True) class MNISTModel(object):
def __init__(self, lr, batch_size, iter_num):
self.lr = lr
self.batch_size = batch_size
self.iter_num = iter_num
# 定义模型结构
# 输入张量,这里还没有数据,先占个地方,所以叫“placeholder”
self.x = tf.placeholder(tf.float32, [None, 784]) # 图像是28*28的大小
self.y = tf.placeholder(tf.float32, [None, 10]) # 输出是0-9的one-hot向量
self.h = tf.layers.dense(self.x, 100, activation=tf.nn.relu, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 一个全连接层
self.y_ = tf.layers.dense(self.h, 10, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 全连接层 # 使用交叉熵损失函数
self.loss = tf.losses.softmax_cross_entropy(self.y, self.y_)
self.optimizer = tf.train.AdamOptimizer()
self.train_step = self.optimizer.minimize(self.loss) # 用于模型训练
self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(self.y_, axis=1))
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) # 用于保存训练好的模型
self.saver = tf.train.Saver() def train(self):
with tf.Session() as sess: # 打开一个会话。可以想象成浏览器打开一个标签页一样,直观地理解一下
sess.run(tf.global_variables_initializer()) # 先初始化所有变量。
for i in range(self.iter_num):
batch_x, batch_y = mnist.train.next_batch(self.batch_size) # 读取一批数据
loss, _ = sess.run([self.loss, self.train_step], feed_dict={self.x: batch_x, self.y: batch_y}) # 每调用一次sess.run,就像拧开水管一样,所有self.loss和self.train_step涉及到的运算都会被调用一次。
if i%1000 == 0:
train_accuracy = sess.run(self.accuracy, feed_dict={self.x: batch_x, self.y: batch_y}) # 把训练集数据装填进去
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y}) # 把测试集数据装填进去
print( 'iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f' % (i,loss,train_accuracy,test_accuracy))
self.saver.save(sess, 'model/mnistModel') # 保存模型 def test(self):
with tf.Session() as sess:
self.saver.restore(sess, 'model/mnistModel')
Accuracy = []
for i in range(150):
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y})
Accuracy.append(test_accuracy)
print ('==' * 15)
print ('Test Accuracy: ', np.mean(np.array(Accuracy))) model = MNISTModel(0.001, 64, 40000) # 学习率为0.001,每批传入64张图,训练40000次
model.train() # 训练模型
model.test() #测试模型

基于多层感知机的手写数字识别(Tensorflow实现)的更多相关文章

  1. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  2. Mnist手写数字识别 Tensorflow

    Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...

  3. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  4. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  5. 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别

    from numpy import * def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i ...

  6. MNIST手写数字识别 Tensorflow实现

    def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 1. strides在官方定义中是一 ...

  7. Keras mlp 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...

  8. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  9. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

随机推荐

  1. To datafix AR DATE

    http://www.cnblogs.com/benio/archive/2012/07/07/2580203.html AR transactions should be created on 6- ...

  2. Windows核心编程:第11章 Windows线程池

    Github https://github.com/gongluck/Windows-Core-Program.git //第11章 Windows线程池.cpp: 定义应用程序的入口点. // #i ...

  3. ASP.NET判断是否为手机登录

    protected void Page_Load(object sender, EventArgs e) { MobileHandle(); } 页面加载时候判断是否为手机登录 protected v ...

  4. cesiumjs

    官网 https://cesium.com/ https://cesiumjs.org/Cesium/ 论坛 http://cesium.coinidea.com/ 中文网 http://cesium ...

  5. dbcp和druid(数据库连接池)

    前言: 对于一个简单的数据库应用,由于对于数据库的访问不是很频繁.这时可以简单地在需要访问数据库时,就新创建一个连接,用完后就关闭它,这样做也不会带来什么明显的性能上的开销.但是对于一个复杂的数据库应 ...

  6. ionic3.x angular4.x ng4.x 自定义组件component双向绑定之自定义计数器

    本文主要示例在ionic3.x环境下实现一个自定义计数器,实现后最终效果如图: 1.使用命令创建一个component ionic g component CounterInput 类似的命令还有: ...

  7. Android------------------RecyclerView学习

    一.多种布局的保存 1.Type->getItemViewType(int position) 2.RecylerView.Holder       : 定一个holder的内部类,里面保存一些 ...

  8. delphi datetimepicker 修改时间无效问题

    今天用delphi写个程序,使用datetimepicker获得想要的时间.蛋疼的问题是无论怎么调整明明看着控件里面的日期变了,但是show出来的datetimepicker.datetime日期都不 ...

  9. 微服务统一登陆认证怎么做?JWT ?

    无状态登录原理 1.1.什么是有状态? 有状态服务,即服务端需要记录每次会话的客户端信息,从而识别客户端身份,根据用户身份进行请求的处理,典型的设计如tomcat中的session. 例如登录:用户登 ...

  10. 使用Flask+uwsgi+Nginx部署Flask正式环境

    环境准备 在开始正式讲解之前,我们将首先进行环境准备. Step1:安装Python,pip以及nginx: sudo apt-get update sudo apt-get install pyth ...