对tf.nn.softmax的理解

转载自律者自由 最后发布于2018-10-31 16:39:40 阅读数 25096  收藏

Softmax的含义:Softmax简单的说就是把一个N*1的向量归一化为(0,1)之间的值,由于其中采用指数运算,使得向量中数值较大的量特征更加明显。

如图所示,在等号左边部分就是全连接层做的事。

  1. W是全连接层的参数,我们也称为权值;W是全连接层的参数,是个T*N的矩阵,这个N和X的N对应,T表示类别数,比如你进行手写数字识别,就是10个分类,那么T就是10。
  2. X是全连接层的输入,也就是特征。从图上可以看出特征X是N1的向量,他就是由全连接层前面多个卷积、激活和池化层处理后得到的;
    举一个例子,假设全连接层前面连接的是一个卷积层,这个卷积层的输出是64个特征,每个特征的大小是7X7,那么在将这些特征输入给全连接层之前会将这些特征通过tf.reshape转化为成N
    1的向量(这个时候N就是64X7X7=3136)。
    我们所说的训练一个网络,对于全连接层而言就是寻找最合适的W矩阵。因此全连接层就是执行WX得到一个T1的向量(也就是图中的logits[T1]),这个向量里面的每个数都没有大小限制的,也就是从负无穷大到正无穷大。然后如果你是多分类问题,一般会在全连接层后面接一个softmax层,这个softmax的输入是T1的向量,输出也是T1的向量(也就是图中的prob[T*1],这个向量的每个值表示这个样本属于每个类的概率),只不过输出的向量的每个值的大小范围为0到1。

现在你知道softmax的输出向量是什么意思了,就是概率,该样本属于各个类的概率!

那么softmax执行了什么操作可以得到0到1的概率呢?先来看看softmax的公式(以前自己看这些内容时候对公式也很反感,不过静下心来看就好了):

公式非常简单,前面说过softmax的输入是WX,假设模型的输入样本是I,讨论一个3分类问题(类别用1,2,3表示),样本I的真实类别是2,那么这个样本I经过网络所有层到达softmax层之前就得到了WX,也就是说WX是一个31的向量,那么上面公式中的aj就表示这个31的向量中的第j个值(最后会得到S1,S2,S3);而分母中的ak则表示31的向量中的3个值,所以会有个求和符号(这里求和是k从1到T,T和上面图中的T是对应相等的,也就是类别数的意思,j的范围也是1到T)。因为e^x恒大于0,所以分子永远是正数,分母又是多个正数的和,所以分母也肯定是正数,因此Sj是正数,而且范围是(0,1)。如果现在不是在训练模型,而是在测试模型,那么当一个样本经过softmax层并输出一个T*1的向量时,就会取这个向量中值最大的那个数的index作为这个样本的预测标签。

因此我们训练全连接层的W的目标就是使得其输出的WX在经过softmax层计算后其对应于真实标签的预测概率要最高。

举个例子:假设你的WX=[1,2,3],那么经过softmax层后就会得到[0.09,0.24,0.67],这三个数字表示这个样本属于第1,2,3类的概率分别是0.09,0.24,0.67。取概率最大的0.67,所以这里得到的预测值就是第三类。

弄懂了softmax,就要来说说softmax loss了。
那softmax loss是什么意思呢?如下:

首先L是损失。Sj是softmax的输出向量S的第j个值,前面已经介绍过了,表示的是这个样本属于第j个类别的概率。yj前面有个求和符号,j的范围也是1到类别数T,因此y是一个1*T的向量,里面的T个值,而且只有1个值是1,其他T-1个值都是0。那么哪个位置的值是1呢?答案是真实标签对应的位置的那个值是1,其他都是0。所以这个公式其实有一个更简单的形式:

当然此时要限定j是指向当前样本的真实标签。

来举个例子吧。假设一个5分类问题,然后一个样本I的标签y=[0,0,0,1,0],也就是说样本I的真实标签是4,假设模型预测的结果概率(softmax的输出)p=[0.1,0.15,0.05,0.6,0.1],可以看出这个预测是对的,那么对应的损失L=-log(0.6),也就是当这个样本经过这样的网络参数产生这样的预测p时,它的损失是-log(0.6)。那么假设p=[0.15,0.2,0.4,0.1,0.15],这个预测结果就很离谱了,因为真实标签是4,而你觉得这个样本是4的概率只有0.1(远不如其他概率高,如果是在测试阶段,那么模型就会预测该样本属于类别3),对应损失L=-log(0.1)。那么假设p=[0.05,0.15,0.4,0.3,0.1],这个预测结果虽然也错了,但是没有前面那个那么离谱,对应的损失L=-log(0.3)。我们知道log函数在输入小于1的时候是个负数,而且log函数是递增函数,所以-log(0.6) < -log(0.3) < -log(0.1)。简单讲就是你预测错比预测对的损失要大,预测错得离谱比预测错得轻微的损失要大。

理清了softmax loss,就可以来看看cross entropy了。
corss entropy是交叉熵的意思,它的公式如下:

是不是觉得和softmax loss的公式很像。当cross entropy的输入P是softmax的输出时,cross entropy等于softmax loss。Pj是输入的概率向量P的第j个值,所以如果你的概率是通过softmax公式得到的,那么cross entropy就是softmax loss。

对tf.nn.softmax的理解的更多相关文章

  1. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  2. tf.nn.softmax(logits,name=None)

    tf.nn.softmax( logits, axis=None, name=None, dim=None #dim在后来改掉了 ) 通过Softmax回归,将logistic的预测二分类的概率的问题 ...

  3. tf.nn.softmax & tf.nn.reduce_sum & tf.nn.softmax_cross_entropy_with_logits

    tf.nn.softmax softmax是神经网络的最后一层将实数空间映射到概率空间的常用方法,公式如下: \[ softmax(x)_i=\frac{exp(x_i)}{\sum_jexp(x_j ...

  4. tf.nn.softmax 分类

    tf.nn.softmax(logits,axis=None,name=None,dim=None) 参数: logits:一个非空的Tensor.必须是下列类型之一:half, float32,fl ...

  5. TensorFlow 的softmax实例理解

    对于理论,简单的去看一下百度上的说明,这里直接上实例,帮助理解. # softmax函数,将向量映射到0-1的范围内,P=exp(ax)/(sum(exp(a1x)+exp(a2x)+...)) in ...

  6. 深度学习原理与框架-图像补全(原理与代码) 1.tf.nn.moments(求平均值和标准差) 2.tf.control_dependencies(先执行内部操作) 3.tf.cond(判别执行前或后函数) 4.tf.nn.atrous_conv2d 5.tf.nn.conv2d_transpose(反卷积) 7.tf.train.get_checkpoint_state(判断sess是否存在

    1. tf.nn.moments(x, axes=[0, 1, 2])  # 对前三个维度求平均值和标准差,结果为最后一个维度,即对每个feature_map求平均值和标准差 参数说明:x为输入的fe ...

  7. 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)

    1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...

  8. 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)

    1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...

  9. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

随机推荐

  1. php--0与空的判断

    使用empty()函数判断,两者都是true $a=0; if(trim($a)=="") { echo '数字0'; }

  2. ffmpeg android移植

    CMake语法简介(androidstudio中利用CMake开发NDK): http://blog.csdn.net/u013718120/article/details/62883711FFmpe ...

  3. 图文并解Word插入修改删除批注

    .插入批注 首先选择对象,比如部分文字[hd1] ,之后执行这样的操作:"插入"→"批注":插入的批注处于编辑状态,可以直接输入批注的文字即可;图解如下: .修 ...

  4. 使用apktool反编译apk文件

    Apktool https://ibotpeaches.github.io/Apktool/install/ 下载地址:Apktool https://bitbucket.org/iBotPeache ...

  5. staruml百度网盘下载

    分享staruml官方百度网盘下载 下载时间:2019年9月4日 21:27:37  StarUML(简称SU),是一种创建UML类图,生成类图和其他类型的统一建模语言(UML)图表的工具.StarU ...

  6. JS截取字符串方法集合

    使用 substring()或者slice()   函数:split() 功能:使用一个指定的分隔符把一个字符串分割存储到数组 例子: str="jpg|bmp|gif|ico|png&qu ...

  7. 在python中使用json

    在服务器和客户端的数据交互的时候,要找到一种数据格式,服务端好处理,客户端也好处理,这种数据格式应该是一种统一的标准,不管在哪里端处理起来都是统一的,现在这种数据格式非常的多,比如最早的xml,再后来 ...

  8. Android App 测试工具及知识大集合

    简介: 作者从事测试将近11年,有8年的团队管理经验,经历了上市公司,外包,日企,股份制公司的企业文化洗礼,擅长测试团队的组建,流程建立,改造,质量体系建建设,有三次经历在不同企业文化从"0 ...

  9. 查漏补缺:进程间通信(IPC):管道

    管道是UNIX系统IPC的最古老形式,所有UNIX系统都提供此种通信机制.管道有以下两种局限性: (1)历史上,管道是半双工的(即数据只能在一个方向上流动). (2)管道只能在具有公共先祖的两个进程之 ...

  10. 为什么我们要让人工智能玩游戏:微软Project AIX

    <我的世界>游戏 2016年7月注:Project AIX已正式更名为Project Malmo 注:本文编译自Project AIX: Using Minecraft to build ...