基于TensorFlow的MNIST手写数字识别-初级
一:MNIST数据集
MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件
分别是test set images,test set labels,training set images,training set labels
training set包括60000个样本,test set包括10000个样本。
test set中前5000个样本来自原始的NISTtraining set,后5000个样本来自原始的NIST test set,因此,前5000个样本比后5000个样本更简单和干净。
每个样本是28*28像素的图片
二:tensorflow构建模型识别MNIST
导入数据:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10]) #真实值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, w) + b) #预测值
softmax的目的:将输出转化为是每个数字的概率
#计算交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label *tf.log(y), reduction_indices=[1]))
train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
交叉熵:衡量预测值与真实值之间的差别,当然是越小越好
公式为:
其中y'是真实值,y为预测值
最后用梯度下降法优化参数即可
在Session中运行graph:
total_steps = 5000
batch_size = 100
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps+1):
batch_x, batch_y = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x: batch_x, y_label: batch_y})
预测正确率:
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.argmax()函数返回axis轴上最大值的index
tf.equal()函数返回的是布尔值,需要用tf.cast()方法转为tf.float32类型
最后在test set上进行预测:
step_per_test = 100
if step % step_per_test == 0:
print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))
完整代码如下:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, w) + b) #计算交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label *tf.log(y), reduction_indices=[1]))
train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#eval
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) total_steps = 5000
batch_size = 100
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps+1):
batch_x, batch_y = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x: batch_x, y_label: batch_y}) step_per_test = 100
if step % step_per_test == 0:
print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))
运行结果:
准确率为0.92左右
后面我们会构建更好的模型达到更高的正确率。
相关链接:
基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型
基于tensorflow的MNIST手写数字识别(二)--入门篇
基于tensorflow的MNIST手写数字识别(三)--神经网络篇
基于TensorFlow的MNIST手写数字识别-初级的更多相关文章
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 基于TensorFlow的MNIST手写数字识别-深入
构建多层卷积神经网络时需要多组W和偏移项b,我们封装2个方法来产生W和b 初级MNIST中用0初始化W和b,这里用噪声初始化进行对称打破,防止产生梯度0,同时用一个小的正值来初始化b避免dead ne ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- Tensorflow实现MNIST手写数字识别
之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- Tensorflow之MNIST手写数字识别:分类问题(2)
整体代码: #数据读取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorfl ...
- TensorFlow——MNIST手写数字识别
MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/ 一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
随机推荐
- NB的程序员,亮瞎了你的眼吗?
郑重声明: 本文首发于人工博客 1.导读 你能想象到1K的代码能写出什么样的功能强大.效果炫酷的作品吗?来吧,今天小编带领大家认识下下面这位大神的作品. 西班牙程序员Roman Cortes用纯Jav ...
- Intellij IDEA2019.1.3破解
下载 JetbrainsCrack.jar(链接:https://pan.baidu.com/s/1Dkw1PruzBlEMjcYszNlSZA 提取码:2bf7),放到bin目录下(其实位置可以随便 ...
- POJ 3304 Segments(判断直线与线段是否相交)
题目传送门:POJ 3304 Segments Description Given n segments in the two dimensional space, write a program, ...
- 小白学 Python 爬虫(37):爬虫框架 Scrapy 入门基础(五) Spider Middleware
人生苦短,我用 Python 前文传送门: 小白学 Python 爬虫(1):开篇 小白学 Python 爬虫(2):前置准备(一)基本类库的安装 小白学 Python 爬虫(3):前置准备(二)Li ...
- c++ beep 演奏一次质量不高的天空之城
beep函数用法: beep(HZ,time); hz是发出多少赫兹声音,time是发声时间(ms) 话不多说,上代码 #include <cstdio> #include <win ...
- 【一起学源码-微服务】Hystrix 源码一:Hystrix基础原理与Demo搭建
说明 原创不易,如若转载 请标明来源! 欢迎关注本人微信公众号:壹枝花算不算浪漫 更多内容也可查看本人博客:一枝花算不算浪漫 前言 前情回顾 上一个系列文章讲解了Feign的源码,主要是Feign动态 ...
- 探究公钥、私钥、对称加密、非对称加密、hash加密、数字签名、数字证书、CA认证、https它们究竟是什么,它们分别解决了通信过程的哪些问题。
一.准备 1. 角色:小白.美美.小黑. 2. 剧情:小白和美美在谈恋爱:小黑对美美求而不得.心生怨念,所以从中作梗. 3. 需求:小白要与美美需通过网络进行通信,联络感情,所以必须保证通信的安全性. ...
- cogs 1176. [郑州101中学] 月考 Map做法
1176. [郑州101中学] 月考 ★★☆ 输入文件:mtest.in 输出文件:mtest.out 简单对比时间限制:1 s 内存限制:128 MB [题目描述] 在上次的月考中B ...
- 如何添加.pch文件
1.Create a pch , call name is project+xxx.pch For example: DuoME-PrefixHeader.pch 2.在project——>Bu ...
- proxy应用场景
//场景一:可以修改对象的值let o = { name: 'xiaoming', price: 190 } let d = new Proxy(o,{ get (target,key){ if(ke ...