基于多层感知机的手写数字识别(Tensorflow实现)
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
class MNISTModel(object):
def __init__(self, lr, batch_size, iter_num):
self.lr = lr
self.batch_size = batch_size
self.iter_num = iter_num
# 定义模型结构
# 输入张量,这里还没有数据,先占个地方,所以叫“placeholder”
self.x = tf.placeholder(tf.float32, [None, 784]) # 图像是28*28的大小
self.y = tf.placeholder(tf.float32, [None, 10]) # 输出是0-9的one-hot向量
self.h = tf.layers.dense(self.x, 100, activation=tf.nn.relu, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 一个全连接层
self.y_ = tf.layers.dense(self.h, 10, use_bias=True, kernel_initializer=tf.truncated_normal_initializer) # 全连接层
# 使用交叉熵损失函数
self.loss = tf.losses.softmax_cross_entropy(self.y, self.y_)
self.optimizer = tf.train.AdamOptimizer()
self.train_step = self.optimizer.minimize(self.loss)
# 用于模型训练
self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(self.y_, axis=1))
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
# 用于保存训练好的模型
self.saver = tf.train.Saver()
def train(self):
with tf.Session() as sess: # 打开一个会话。可以想象成浏览器打开一个标签页一样,直观地理解一下
sess.run(tf.global_variables_initializer()) # 先初始化所有变量。
for i in range(self.iter_num):
batch_x, batch_y = mnist.train.next_batch(self.batch_size) # 读取一批数据
loss, _ = sess.run([self.loss, self.train_step], feed_dict={self.x: batch_x, self.y: batch_y}) # 每调用一次sess.run,就像拧开水管一样,所有self.loss和self.train_step涉及到的运算都会被调用一次。
if i%1000 == 0:
train_accuracy = sess.run(self.accuracy, feed_dict={self.x: batch_x, self.y: batch_y}) # 把训练集数据装填进去
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y}) # 把测试集数据装填进去
print( 'iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f' % (i,loss,train_accuracy,test_accuracy))
self.saver.save(sess, 'model/mnistModel') # 保存模型
def test(self):
with tf.Session() as sess:
self.saver.restore(sess, 'model/mnistModel')
Accuracy = []
for i in range(150):
test_x, test_y = mnist.test.next_batch(self.batch_size)
test_accuracy = sess.run(self.accuracy, feed_dict={self.x: test_x, self.y: test_y})
Accuracy.append(test_accuracy)
print ('==' * 15)
print ('Test Accuracy: ', np.mean(np.array(Accuracy)))
model = MNISTModel(0.001, 64, 40000) # 学习率为0.001,每批传入64张图,训练40000次
model.train() # 训练模型
model.test() #测试模型
基于多层感知机的手写数字识别(Tensorflow实现)的更多相关文章
- 基于Numpy的神经网络+手写数字识别
基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...
- Mnist手写数字识别 Tensorflow
Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 基于卷积神经网络的手写数字识别分类(Tensorflow)
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...
- 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别
from numpy import * def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i ...
- MNIST手写数字识别 Tensorflow实现
def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 1. strides在官方定义中是一 ...
- Keras mlp 手写数字识别示例
#基于mnist数据集的手写数字识别 #构造了三层全连接层组成的多层感知机,最后一层为输出层 #基于Keras 2.1.1 Tensorflow 1.4.0 代码: import keras from ...
- 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)
主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...
- Keras cnn 手写数字识别示例
#基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...
随机推荐
- WPF 嵌入Winform GDI 、 开启AllowsTransparenc问题
此文章可以解决2至少2个问题: 1.开启AllowsTransparenc造成的GDI+组件不显示问题 2.WPF 组件无法覆盖嵌入WPF窗口的任何第三方GDI+组件上层 方案1:自制双层 原理:用一 ...
- XML随笔:语法快速入门及当下流行的RSS简析
今天是本人第一次写博客,之前闭门造车闹出过很多笑话,恰巧这几天刚刚重温了一遍XML的知识,决定把XML的知识再来从头到尾的理一遍,感触颇多,今天分享给大家.希望大家能多多注意其中的要点. 1.定义 首 ...
- sharepoint support ashx file
Hello, I did the steps from the tutorial you are using. I have received the same error when I did no ...
- Docker Compose模板文件介绍
模板文件是使用 Compose 的核心,涉及到的指令关键字也比较多,这里面大部分指令跟 docker run 相关参数的含义都是类似的.默认的模板文件名称为 docker-compose.yml ,格 ...
- Linux巩固记录(4) 运行hadoop 2.7.4自带demo程序验证环境
本节主要使用hadoop自带的程序运行demo来确认环境是否正常 1.首先创建一个input.txt文件,里面任意输入些单词,有部分重复单词 2.将input文件拷贝到hdfs 3.执行hadoop程 ...
- git log 高级用法
转自:https://github.com/geeeeeeeeek/git-recipes/wiki/5.3-Git-log%E9%AB%98%E7%BA%A7%E7%94%A8%E6%B3%95 内 ...
- tomcat服务的启动与隐藏启动(win)
一: tomcat的启动与隐藏启动 1. 正常启动:D:\apache-tomcat-8.5.24\bin中的 startup.bat 双击启动 2. 启动tomcat服务后,window下方 ...
- python 相关模块安装 国内镜像地址
python 相关模块安装 国内镜像地址 pipy国内镜像目前有: http://pypi.douban.com/ 豆瓣 http://pypi.hustunique.com/ 华中理工大学 ht ...
- 安装eclipse启动时报错
1.在安装eclipse后,点击exe文件时,提示出现错误,记录在log文件中,因为log文件就是日志文件,可以方便我们排查错误,打开log文件,可以看到文件记录了每次出错的时间和错误栈信息,最新一次 ...
- 了解fortran语言
最近看了一些文献,发现用了Fortran语言编程,并且还是近几年的,了解了之后才知道,其实Fortran已经慢慢没有人再用了,之所有还有一批人在用,极大可能是历史遗留问题吧.而这,也得从Fortran ...