softmax函数的作用

  对于分类方面,softmax函数的作用是从样本值计算得到该样本属于各个类别的概率大小。例如手写数字识别,softmax模型从给定的手写体图片像素值得出这张图片为数字0~9的概率值,这些概率值之和为1。预测的结果取最大的概率表示的数字作为这张图片的分类。

可以从下面这张图理解softmax

x1,x2,x3代表输入的值,b1,b2,b3代表类别1,2,3的偏置量,是因为输入的值可能存在无关的干扰量。
将上图写成等式
$$
\left[\begin{matrix}temp_1\\temp_2\\temp_3\end{matrix}\right]
=\left(\begin{matrix}W_{1,1}x_1+W_{1,2}x_2+W_{1,3}x_3+b_1\\
W_{2,1}x_1+W_{2,2}x_2+W_{2,3}x_3+b_2\\
W_{3,1}x_1+W_{3,2}x_2+W_{3,3}x_3+b_3\end{matrix}\right)\\
\left[\begin{matrix}y_1\\y_2\\y_3\end{matrix}\right]
=softmax\left(\begin{matrix}temp_1\\
temp_2\\
temp_3\end{matrix}\right)\\
其中y_i = softmax(temp_i) = \frac{exp(temp_i)}{\sum_{j=0}^{n}exp(temp_j)}\\
y_1,y_2,y_3分别表示该样本属于类别1,2,3的概率值。
$$
  在神经网络中,通过训练集训练模型中的权重值W和偏置值b,来提高分类的准确性。
(训练方法是定义一个损失函数(代表预测值与真实值之间的差异),然后采用梯度下降法(通过对W,b求偏导)来最小化这个损失函数,具体过程有点复杂,下面只是直接拿tensorflow的函数来实现,后面有空的话再来补充原理)

用Tensorflow实现手写数字识别

首先从tensorflow导入mnist数据集,里面包含了手写数字图片的像素矩阵,以及这些图片所对应的数字类别:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

说明一下图片的像素矩阵是将28x28压平为[1x784]大小的向量;标签是[1x10]的向量,其中某一个数是1,其余全为0,比如说如果标签表示的是数字5,那么这个标签向量为[0,0,0,0,1,0,0,0,0,0]。

构建模型:

x = tf.placeholder("float",[None,784])
#一个二维向量的占位符,None表示第一位可以是任意长度,784表示一张图片压平后的长度
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10])) #temp = x*W + b
#softmax(temp)得到一个[None,10]的向量,表示None个图片可能代表0~9的概率。
y = tf.nn.softmax(tf.matmul(x,W)+b)

构建模型训练过程:定义损失函数,最小化这个损失函数,从而得到W,b

y_ = tf.placeholder("float",[None,10])
#这里用占位符来代表y_(每个图片的真实类别),后面运行时会将真实类别填给占位符。
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#y是模型的预测类别,y_是真实类别,用交叉熵来代表损失函数(说明预测值和真实值之间的差异)
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#用梯度下降法来最小化损失函数

运行之前构造的模型:

init = tf.initialize_all_variables()#init表示初始化所有变量
sess = tf.Session()#启动会话,用于运行模型
sess.run(init)#运行init才真正的使所有变量初始化
for i in range(1000):#训练模型1000遍
batch_xs,batch_ys = mnist.train.next_batch(100)
#从数据集中取出100个样本
sess.run(train_step, feed_dict={x:batch_xs, y_:batch_ys})
#将样本填入之前定义的占位符,然后运行刚才构建的训练过程

评估模型:

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#逐个判断预测值和真实值是否相等,返回一个矩阵。
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
#tf.cast将bool型转化为float型,reduce_mean计算平均值(即正确率)
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
#将测试集填入之前的占位符,运行之前的模型,得到正确率

输出结果为:

0.9181

总结

  tensorflow让用户先从更大的层面上构建模型,其中需要的数据先由占位符代替,然后在运行模型时再填入对应的数据。用户不需要对具体运算过程一步步编程实现,使得神经网络的构建简便了许多。



正在学习tensorflows时写的笔记,欢迎评论探讨!

参考网址:tensorflow中文社区

Softmax用于手写数字识别(Tensorflow实现)-个人理解的更多相关文章

  1. Mnist手写数字识别 Tensorflow

    Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...

  2. MNIST手写数字识别 Tensorflow实现

    def conv2d(x, W): return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 1. strides在官方定义中是一 ...

  3. 基于多层感知机的手写数字识别(Tensorflow实现)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  4. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  5. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  6. 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)

    通过: 手写数字识别  ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别  ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...

  7. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

  8. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

  9. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

随机推荐

  1. Python基础语法总结【新手必学】

      前言本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理.作者:weixin_45189038直接上知识点: 1. 注释 单行注释: ...

  2. 笔记||Python3之文件的读写

    [文件的读模式]            文件的对象:文件的读写通过文件操作对象进行. Python2  -----  File Python3  -----  TextIOWrapper       ...

  3. 使用java语言实现八皇后问题

    八皇后问题,在一个8X8的棋盘中,放置八个棋子,每个棋子的上下左右,左上左下,右上右下方向上不得有其他棋子.正确答案为92中,接下来用java语言实现. 解: package eightQuen; / ...

  4. 【系列专题】ECMAScript 重温系列(10篇全)

    ES6 系列ECMAScript 2015 [ES]150-重温基础:ES6系列(一) [ES]151-重温基础:ES6系列(二) [ES]152-重温基础:ES6系列(三) [ES]153-重温基础 ...

  5. win10配置git SSH

    1.安装的过程就不说了,直接去官网下载git for windows 安装便可 安装完了,无非就是像用它,就想从github上clone项目下来,仅仅是安装了git还不能直接从远程下载项目下来哦,还需 ...

  6. Weed3 for java 新的微型ORM框架

    Weed3,微型ORM框架(支持:java sql,xml sql,annotation sql:存储过程:事务:缓存:监听:等...) 05年时开发了第一代: 08年时开发了第二代,那时候进入互联网 ...

  7. ThinkPhp RBAC实现原理

    RBAC是英文Role-Based Access Control的缩写,是基于角色访问进行控制的机制.意思是给每个用户设定一个角色,然后根据这个角色来判断用户的权限. 在此基于ThinkPhp的MVC ...

  8. Redis Cluster 的数据分片机制

    上一篇<分布式数据缓存中的一致性哈希算法> 文章中讲述了一致性哈希算法的基本原理和实现,今天就以 Redis Cluster 为例,详细讲解一下分布式数据缓存中的数据分片,上线下线时数据迁 ...

  9. PythonI/O进阶学习笔记_8.python的可迭代对象和迭代器、迭代设计模式

     content: 1.什么是迭代协议 2. 什么是迭代器(Iterator)和可迭代对象(Iterable) 3. 使用迭代器和可迭代对象 4. 创建迭代器和可迭代对象 5. 迭代器设计模式   一 ...

  10. Blockchain 基本知识

    本文是前奏,本来要介绍Azure上的Azure Blockchain Service,发现,需要从什么是区块链开始讲起... 什么是区块链?我们从比特币说起, 2008年11月,中本聪提出了比特币白皮 ...