import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os mnist = input_data.read_data_sets('MNIST_data', one_hot=True) class MNISTModel(object):
def __init__(self, lr, batch_size, iter_num):
self.lr = lr
self.batch_size = batch_size
self.iter_num = iter_num
# 定义模型结构
# 输入张量,这里还没有数据,先占个地方,所以叫“placeholder”
self.x = tf.placeholder(tf.float32, [None, 784]) # 图像是28*28的大小
self.y = tf.placeholder(tf.float32, [None, 10]) # 输出是0-9的one-hot向量
self.h = tf.layers.dense(self.x, 100, activation=tf.nn.relu, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 一个全连接层
self.y_ = tf.layers.dense(self.h, 10, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 全连接层 # 使用交叉熵损失函数
self.loss = tf.losses.softmax_cross_entropy(self.y, self.y_)
self.optimizer = tf.train.AdamOptimizer()
self.train_step = self.optimizer.minimize(self.loss) # 用于模型训练
self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(self.y_, axis=1))
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) # 用于保存训练好的模型
self.saver = tf.train.Saver() def train(self):
with tf.Session() as sess: # 打开一个会话。可以想象成浏览器打开一个标签页一样,直观地理解一下
sess.run(tf.global_variables_initializer()) # 先初始化所有变量。
for i in range(self.iter_num):
batch_x, batch_y = mnist.train.next_batch(self.batch_size) # 读取一批数据
loss, _ = sess.run([self.loss, self.train_step], feed_dict={self.x: batch_x, self.y: batch_y}) # 每调用一次sess.run,就像拧开水管一样,所有self.loss和self.train_step涉及到的运算都会被调用一次。
if i%1000 == 0:
train_accuracy = sess.run(self.accuracy, feed_dict={self.x: batch_x, self.y: batch_y}) # 把训练集数据装填进去
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y}) # 把测试集数据装填进去
print( 'iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f' % (i,loss,train_accuracy,test_accuracy))
self.saver.save(sess, 'model/mnistModel') # 保存模型 def test(self):
with tf.Session() as sess:
self.saver.restore(sess, 'model/mnistModel')
Accuracy = []
for i in range(150):
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y})
Accuracy.append(test_accuracy)
print ('==' * 15)
print ('Test Accuracy: ', np.mean(np.array(Accuracy))) model = MNISTModel(0.001, 64, 40000) # 学习率为0.001,每批传入64张图,训练40000次
model.train() # 训练模型
model.test() #测试模型

基于多层感知机的手写数字识别(Tensorflow实现)的更多相关文章

  1. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  2. Mnist手写数字识别 Tensorflow

    Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...

  3. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  4. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  5. 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别

    from numpy import * def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i ...

  6. MNIST手写数字识别 Tensorflow实现

    def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 1. strides在官方定义中是一 ...

  7. Keras mlp 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...

  8. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  9. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

随机推荐

  1. WPF 嵌入Winform GDI 、 开启AllowsTransparenc问题

    此文章可以解决2至少2个问题: 1.开启AllowsTransparenc造成的GDI+组件不显示问题 2.WPF 组件无法覆盖嵌入WPF窗口的任何第三方GDI+组件上层 方案1:自制双层 原理:用一 ...

  2. XML随笔:语法快速入门及当下流行的RSS简析

    今天是本人第一次写博客,之前闭门造车闹出过很多笑话,恰巧这几天刚刚重温了一遍XML的知识,决定把XML的知识再来从头到尾的理一遍,感触颇多,今天分享给大家.希望大家能多多注意其中的要点. 1.定义 首 ...

  3. sharepoint support ashx file

    Hello, I did the steps from the tutorial you are using. I have received the same error when I did no ...

  4. Docker Compose模板文件介绍

    模板文件是使用 Compose 的核心,涉及到的指令关键字也比较多,这里面大部分指令跟 docker run 相关参数的含义都是类似的.默认的模板文件名称为 docker-compose.yml ,格 ...

  5. Linux巩固记录(4) 运行hadoop 2.7.4自带demo程序验证环境

    本节主要使用hadoop自带的程序运行demo来确认环境是否正常 1.首先创建一个input.txt文件,里面任意输入些单词,有部分重复单词 2.将input文件拷贝到hdfs 3.执行hadoop程 ...

  6. git log 高级用法

    转自:https://github.com/geeeeeeeeek/git-recipes/wiki/5.3-Git-log%E9%AB%98%E7%BA%A7%E7%94%A8%E6%B3%95 内 ...

  7. tomcat服务的启动与隐藏启动(win)

    一:  tomcat的启动与隐藏启动 1. 正常启动:D:\apache-tomcat-8.5.24\bin中的   startup.bat  双击启动 2. 启动tomcat服务后,window下方 ...

  8. python 相关模块安装 国内镜像地址

    python 相关模块安装 国内镜像地址 pipy国内镜像目前有: http://pypi.douban.com/  豆瓣 http://pypi.hustunique.com/  华中理工大学 ht ...

  9. 安装eclipse启动时报错

    1.在安装eclipse后,点击exe文件时,提示出现错误,记录在log文件中,因为log文件就是日志文件,可以方便我们排查错误,打开log文件,可以看到文件记录了每次出错的时间和错误栈信息,最新一次 ...

  10. 了解fortran语言

    最近看了一些文献,发现用了Fortran语言编程,并且还是近几年的,了解了之后才知道,其实Fortran已经慢慢没有人再用了,之所有还有一批人在用,极大可能是历史遗留问题吧.而这,也得从Fortran ...