import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os mnist = input_data.read_data_sets('MNIST_data', one_hot=True) class MNISTModel(object):
def __init__(self, lr, batch_size, iter_num):
self.lr = lr
self.batch_size = batch_size
self.iter_num = iter_num
# 定义模型结构
# 输入张量,这里还没有数据,先占个地方,所以叫“placeholder”
self.x = tf.placeholder(tf.float32, [None, 784]) # 图像是28*28的大小
self.y = tf.placeholder(tf.float32, [None, 10]) # 输出是0-9的one-hot向量
self.h = tf.layers.dense(self.x, 100, activation=tf.nn.relu, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 一个全连接层
self.y_ = tf.layers.dense(self.h, 10, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 全连接层 # 使用交叉熵损失函数
self.loss = tf.losses.softmax_cross_entropy(self.y, self.y_)
self.optimizer = tf.train.AdamOptimizer()
self.train_step = self.optimizer.minimize(self.loss) # 用于模型训练
self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(self.y_, axis=1))
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) # 用于保存训练好的模型
self.saver = tf.train.Saver() def train(self):
with tf.Session() as sess: # 打开一个会话。可以想象成浏览器打开一个标签页一样,直观地理解一下
sess.run(tf.global_variables_initializer()) # 先初始化所有变量。
for i in range(self.iter_num):
batch_x, batch_y = mnist.train.next_batch(self.batch_size) # 读取一批数据
loss, _ = sess.run([self.loss, self.train_step], feed_dict={self.x: batch_x, self.y: batch_y}) # 每调用一次sess.run,就像拧开水管一样,所有self.loss和self.train_step涉及到的运算都会被调用一次。
if i%1000 == 0:
train_accuracy = sess.run(self.accuracy, feed_dict={self.x: batch_x, self.y: batch_y}) # 把训练集数据装填进去
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y}) # 把测试集数据装填进去
print( 'iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f' % (i,loss,train_accuracy,test_accuracy))
self.saver.save(sess, 'model/mnistModel') # 保存模型 def test(self):
with tf.Session() as sess:
self.saver.restore(sess, 'model/mnistModel')
Accuracy = []
for i in range(150):
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y})
Accuracy.append(test_accuracy)
print ('==' * 15)
print ('Test Accuracy: ', np.mean(np.array(Accuracy))) model = MNISTModel(0.001, 64, 40000) # 学习率为0.001,每批传入64张图,训练40000次
model.train() # 训练模型
model.test() #测试模型

基于多层感知机的手写数字识别(Tensorflow实现)的更多相关文章

  1. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  2. Mnist手写数字识别 Tensorflow

    Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...

  3. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  4. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  5. 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别

    from numpy import * def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i ...

  6. MNIST手写数字识别 Tensorflow实现

    def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 1. strides在官方定义中是一 ...

  7. Keras mlp 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...

  8. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  9. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

随机推荐

  1. firefox浏览器testclient测试接口

  2. Android-Kotlin-when&类型推断

    Kotlin的when表达式 TextEngine 描述文字处理对象: package cn.kotlin.kotlin_base02 /** * 描述文字处理对象 * * val textConte ...

  3. Android-Kotlin-抽象类与多态的表现

    选择包名,然后右键: 选择Class类型,会有class:  选择File类型,不会自动有class: 目录结构: 定义描述抽象类 Person人类: package cn.kotlin.kotlin ...

  4. [面试题目]IT面试中的一些基础问题

    1. 面向对象的特征 继承,封装,多态 2. 重写和重载的区别 重写:在继承当中,子类重写父类的函数,函数声明完全一样,只是函数里面的操作不一样,这样叫做重写. 重载:与多态无关,即两个函数名一样的成 ...

  5. BZOJ百题版切计划(不咕)

    传送门 BZOJ 前言 听说最近要省选,那么我就写一下吧.QwQ! 1000 过于简单,不写了. 1001 不会对偶图,直接优化最小割 题解 1002 高精度套公式计算 题解 (Code by hey ...

  6. C语言实现windows进程遍历

    #include <windows.h> #include <tlhelp32.h> //进程快照函数头文件 #include <stdio.h> int main ...

  7. mac终端常用命令

    1.du #查看文件目录大小 示例:查看DataCenter目录下所有文件/文件夹的大小 everSeeker:DataCenter pingping$ -h .9G ./Books 1.2M ./C ...

  8. 修改windows远程默认端口

    修改windows远程默认端口 windows端口修改rdp 1 远程服务器运行窗口调出注册表编辑器 注册表编辑器regeidt 2 修改两个注册表 1,在注册表HKEY_LOCAL_MACHINE\ ...

  9. JS简单表单验证

    这里我是写了一个简单的注册表单验证功能,亲测有效,一起来看看吧! 首先我的HTML代码是这样的: class大家可以忽略一下,这里我项目使用的是bootstrap的样式. 输入用户名和密码用的是正则表 ...

  10. javascript 计算文件MD5 浏览器 javascript读取文件内容

    原则上说,浏览器是一个不安全的环境.早期浏览器的内容是静态的,用户上网冲浪,一般就是拉取网页查看.后来,随着互联网的发展,浏览器提供了非常丰富的用户交互功能.从早期的表单交互,到现在的websocke ...