1 MNIST数据集

MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字。原始的MNIST数据库一共包含下面4个文件,见下表。

训练图像一共有60000张,供研究人员训练出合适的模型。测试图像一共有10000张,供研究人员测试训练的模型的性能。

2 Softmax 回归

Softmax回归是一个线性的多类分类模型,实际上它是直接从Logistic回归模型转化而来的。区别在于Logistic 回归模型为两类分类模型,而Softmax 模型为多类分类模型。

在手写体识别问题中,一共有10个类别(0~9),我们希望对输入的图像计算它属于每个类别的概率。如属于9的概率为70%,属于1的概率为10%等。最后模型预测的结果就是概率最大的那个类别。

先来了解什么是Softmax函数。Softmax函数的主要功能是将各个类别的“打分”转化成合理的概率值。例如,一个样本可能属于三个类别:第一个类别的打分为a,第二个类别的打分为b,第三个类别的打分为c。打分越高代表属于这个类别的概率越高,但是打分本身不代表概率,因为打分的值可以是负数,也可以很大,但概率要求值必须在0~1,并且三类的概率加起来应该等于1。那么,如何将(a,b,c)转换成合理的概率值呢?方法就是使用Softmax函数。例如,对(a,b,c)使用Softmax函数后,相应的值会变成如下所示的形式:

这三个数值都在0~1之间,并且加起来正好等于1,是合理的概率表示。

假设x是单个样本的特征,W、b是Softmax模型的参数。在MNIST数据集中,x就代表输入图片,它是一个784维的向量,而W是一个矩阵,它的形状为(784,10),b是一个10维的向量,10代表的是类别数。Softmax模型的第一步是通过下面的公式计算各个类别的Logit:

Logit 同样是一个10维的向量,它实际上可以看成样本对应于各个类别的“打分”。接下来使用Softmax函数将它转换成各个类别的概率值:

3 tensorflow 实现

# -*- coding: utf-8 -*-

import tensorflow as tf

#导入mnist教学的模块
from tensorflow.examples.tutorials.mnist import input_data #读入mnist数据
mnist = input_data.read_data_sets("MNIST_data/",one_hot = True) #创建x,x是一个占位符,代表待识别的图片
x = tf.placeholder(tf.float32,[None,784]) # w是softmax模型的参数,将一个784的输入转换为一个10位的输出
w = tf.Variable(tf.zeros([784,10])) # b是又一个softmax的参数,一般叫做“偏置项
b = tf.Variable(tf.zeros([10])) # y表示模型的输出
y = tf.nn.softmax(tf.matmul(x,w) + b) # y_是实际的图像标签,同样以占位符表示
y_ = tf.placeholder(tf.float32,[None,10]) #根据y和y_构造交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y))) #有了交叉熵,就可以使用梯度下降法针对模型的参数(w和b)进行优化
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #创建一个Session。只有在Session中才能运行优化步骤train_step
sess = tf.InteractiveSession()
#运行之前必须要初始化所有的变量,分配内存
tf.global_variables_initializer().run() # 进行1000步梯度下降
for _ in range(1000):
#在mnist.train中取100个训练数据
#batch_xs是形状为(100,784)的图像数据,batch_ys是形如(100,10)的实际标签
#batch_xs与batch_ys分别对应着x和y_两个占位符
batch_xs,batch_ys = mnist.train.next_batch(100)
#在Session中运行train_step,运行时要传入占位符的值
sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys}) #正确的预测结果
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#计算预测准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#在Session中运行Tensor可以得到Tensor的值
#这里是获取最终模型的准确率
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels}))

注:

利用TensorFlow识别手写的数字---基于Softmax回归的更多相关文章

  1. 利用TensorFlow识别手写的数字---基于两层卷积网络

    1 为什么使用卷积神经网络 Softmax回归是一个比较简单的模型,预测的准确率在91%左右,而使用卷积神经网络将预测的准确率提高到99%. 2 卷积网络的流程 3 代码展示 # -*- coding ...

  2. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  3. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  4. 07 训练Tensorflow识别手写数字

    打开Python Shell,输入以下代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input ...

  5. 利用Tensorflow实现手写字符识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  6. TensorFlow下利用MNIST训练模型并识别自己手写的数字

    最近一直在学习李宏毅老师的机器学习视频教程,学到和神经网络那一块知识的时候,我觉得单纯的学习理论知识过于枯燥,就想着自己动手实现一些简单的Demo,毕竟实践是检验真理的唯一标准!!!但是网上很多的与t ...

  7. OpenCV+TensorFlow图片手写数字识别(附源码)

    初次接触TensorFlow,而手写数字训练识别是其最基本的入门教程,网上关于训练的教程很多,但是模型的测试大多都是官方提供的一些素材,能不能自己随便写一串数字让机器识别出来呢?纸上得来终觉浅,带着这 ...

  8. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  9. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

随机推荐

  1. Map、可变参数、静态导入、Collections、Arrays、集合嵌套

    Map双列集合 key 键 value 值 方法: put(K, V) //判断key值是否相等,相等,覆盖元素,不相等,存储 remove(K) Map集合的遍历(不要看到java提供了很多方法做一 ...

  2. CSS——div内文字的溢出部分用省略号显示

    使得div内文字的溢出部分用省略号显示,可归纳为两种解决办法,一种方法是用CSS解决,另一种方法是js解决. 一.通过CSS控制显示 div内显示一行,超出部分用省略号显示 div内显示多行,超出部分 ...

  3. 第二章计算机网络ios 模型

    机构: ISO国际标准化组织: ITU国际电信联盟: ANSI 美国国家标准委员会: ECMA欧洲计算机制作商协会 ITEF因特网特别任务组. 协议:为计算机网路中进行数据交换而建立的规则,标准或约定 ...

  4. 【SDOI2015】约数个数和

    题面 求\(\sum_{i=1}^n\sum_{j=1}^md(ij)\) \(\leq 50000\)组数据,\(1\leq n,m\leq 50000\). 题目分析 首先,你需要知道一个结论: ...

  5. CF596D Wilbur and Trees

    题意:有一些高度为h的树在数轴上.每次选择剩下的树中最左边或是最右边的树推倒(各50%概率),往左倒有p的概率,往右倒1-p. 一棵树倒了,如果挨到的另一棵树与该数的距离严格小于h,那么它也会往同方向 ...

  6. CF627A Xor Equation

    题意:a+b=s,a^b=x(异或).问有多少有序Z+对(a,b)满足条件. 标程: #include<cstdio> using namespace std; typedef long ...

  7. 解决mysql中无法修改事务隔离级别的问题

    使用SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED;修改数据库隔离级别, 然后执行SELECT @@TX_ISOLATION;后发现数据库的隔离级别并 ...

  8. mysql之备份表和备份数据库

    备份表 1.首先创建一个与原来一样的表 create table score2 like score; ###like就是将score表的结构拷贝过来,但是它并不执行数据:也就是说执行完上面的语句之后 ...

  9. 从微服务治理的角度看RSocket、. Envoy和. Istio

    很多同学看到这个题目,一定会提这样的问题:RSocket是个协议,Envoy是一个 proxy,Istio是service mesh control plane + data plane. 这三种技术 ...

  10. BZOJ1912:[APIO2010]patrol巡逻

    Description Input 第一行包含两个整数 n, K(1 ≤ K ≤ 2).接下来 n – 1行,每行两个整数 a, b, 表示村庄a与b之间有一条道路(1 ≤ a, b ≤ n). Ou ...