翻译自:https://stackoverflow.com/questions/34240703/whats-the-difference-between-softmax-and-softmax-cross-entropy-with-logits

问题:

在Tensorflow官方文档中,他们使用一个关键词,称为logits。这个logits是什么?比如说在API文档中有很多方法(methods),经常像下面这么写:

tf.nn.softmax(logits, name=None)

另外一个问题是,有2个方法我不知道该怎么区分,它们是:

tf.nn.softmax(logits, name=None)
tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

它们之间的区别是什么?

回答:

简短版本:

假设你有2个tensors,其中y_hat包含每个类预测的得分(比如说,从y = W*x +b计算得到),y_true包含one-hot编码后的正确的label。

y_hat  = ... # Predicted label, e.g. y = tf.matmul(X, W) + b
y_true = ... # True label, one-hot encoded

如果你将y_hat的得分解释为未归一化的log概率,那么它们就是logits

另外,总的交叉熵损失可以用如下方式计算得到:

y_hat_softmax = tf.nn.softmax(y_hat)
total_loss = tf.reduce_mean(-tf.reduce_sum(y_true * tf.log(y_hat_softmax), [1]))

从本质上来说,与用softmax_cross_entropy_with_logits()函数计算得到的总的交叉熵损失是一样的,计算方法为:

total_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true))

完整版:

举个例子,我们创建一个2x3大小的y_hat,其中行对应着训练样本,列对应类别。因此,这里有2个训练样本和3个类别。

import tensorflow as tf
import numpy as np sess = tf.Session() # Create example y_hat.
y_hat = tf.convert_to_tensor(np.array([[0.5, 1.5, 0.1],[2.2, 1.3, 1.7]]))
sess.run(y_hat)
# array([[ 0.5, 1.5, 0.1],
# [ 2.2, 1.3, 1.7]])

注意到这些值并没有归一化(每一行加起来并不等于1)。为了归一化这些数,我们可以使用softmax函数,这个函数的输入就是未归一化的log概率(也称为logits),输出是归一化的线性概率。

y_hat_softmax = tf.nn.softmax(y_hat)
sess.run(y_hat_softmax)
# array([[ 0.227863 , 0.61939586, 0.15274114],
# [ 0.49674623, 0.20196195, 0.30129182]])

完全理解softmax输出的内容是很重要的。下面我将展示一个表格来更加清楚地解释上面的输出。从表格中可以看出来,训练样本实例1属于类别2的概率是0.619,每个训练样本实例的类概率被归一化,所有每一行的和是1.0。

                      Pr(Class 1)  Pr(Class 2)  Pr(Class 3)
--------------------------------------
Training instance 1 | 0.227863 | 0.61939586 | 0.15274114
Training instance 2 | 0.49674623 | 0.20196195 | 0.30129182

现在我们有了每个训练样本在每个类上的概率,我们可以对每一行使用argmax()来产生一个最终的分类结果。从上面的表格上来看,我们可以判断出训练样本实例1属于类别2,训练样本实例2属于类别1。

那么,这些分类正确么?我们需要根据训练样本正确的标签来衡量。你需要一个one-hot编码的y_true数组,其中每一行表示训练样本实例,每一列表示类别。下面我将创建一个例子,y_true为one-hot编码的数组,其中对于训练样本1正确的标签是类别2,对于训练样本2正确的标签是类别3。

y_true = tf.convert_to_tensor(np.array([[0.0, 1.0, 0.0],[0.0, 0.0, 1.0]]))
sess.run(y_true)
# array([[ 0., 1., 0.],
# [ 0., 0., 1.]])

y_hat_softmax的概率分布接近y_true的概率分布么?我们可以使用交叉熵损失( cross-entropy loss)来衡量错误程度。

根据下面的式子,我们可以计算出每一行的交叉熵损失。从下面的结果中可以看出训练样本1的损失为0.479,训练样本2的损失比较高,是1.200。这个结果是有道理的,因为y_hat_softmax显示训练样本1的最高概率是类别2,与正确标签y_true匹配,而训练样本2的最高概率预测为类别1,与实际标签(类别3)不匹配。

loss_per_instance_1 = -tf.reduce_sum(y_true * tf.log(y_hat_softmax), reduction_indices=[1])
sess.run(loss_per_instance_1)
# array([ 0.4790107 , 1.19967598])

我们想要的是训练集上的所有的loss,因此我们需要将每个训练样本的loss加起来,如下:

total_loss_1 = tf.reduce_mean(-tf.reduce_sum(y_true * tf.log(y_hat_softmax), reduction_indices=[1]))
sess.run(total_loss_1)
# 0.83934333897877944

使用softmax_cross_entropy_with_logits()

我们也可以使用tf.nn.softmax_cross_entropy_with_logits()函数来计算整个交叉熵损失,代码如下:

loss_per_instance_2 = tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true)
sess.run(loss_per_instance_2)
# array([ 0.4790107 , 1.19967598]) total_loss_2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y_hat, y_true))
sess.run(total_loss_2)
# 0.83934333897877922

注意到,total_loss_1和total_loss_2计算得到的结果是基本一样的(有一点点小的差别)。然而,还是建议使用第二种方法,原因是1)代码量更少,2)方法二的内部考虑到了一些边界情况,不容易产生错误。

[翻译] softmax和softmax_cross_entropy_with_logits的区别的更多相关文章

  1. PyTorch学习笔记——softmax和log_softmax的区别、CrossEntropyLoss() 与 NLLLoss() 的区别、log似然代价函数

    1.softmax 函数 Softmax(x) 也是一个 non-linearity, 但它的特殊之处在于它通常是网络中一次操作. 这是因为它接受了一个实数向量并返回一个概率分布.其定义如下. 定义 ...

  2. sigmoid和softmax的应用意义区别

    转载自:https://baijiahao.baidu.com/s?id=1636737136973859154&wfr=spider&for=pc写的很清楚,并举例佐证,容易理解,推 ...

  3. Difference between nn.softmax & softmax_cross_entropy_with_logits & softmax_cross_entropy_with_logits_v2

    nn.softmax 和 softmax_cross_entropy_with_logits 和 softmax_cross_entropy_with_logits_v2 的区别   You have ...

  4. CNKI翻译助手-连接数据库失败

    IP并发数限制,老师说西工大的CNKI才20个并发指标,HPU自不必说.但是我略表怀疑,这只是翻译助手而已,就像百度翻译和百度数据库的区别,如何验证呢?去校外用该助手,如果能用,那么就不是IP并发限制 ...

  5. 你真的了解word-wrap和word-break的区别吗?

    这两个东西是什么,我相信至今还有很多人搞不清,只会死记硬背的写一个word-wrap:break-word;word-break:break-all;这样的东西来强制断句,又或者是因为这两个东西实在是 ...

  6. 你真的了解word-wrap和word-break的区别吗? (转载)

    这两个东西是什么,我相信至今还有很多人搞不清,只会死记硬背的写一个word-wrap:break-word;word-break:break-all;这样的东西来强制断句,又或者是因为这两个东西实在是 ...

  7. Pytorch之CrossEntropyLoss() 与 NLLLoss() 的区别

    (三)PyTorch学习笔记——softmax和log_softmax的区别.CrossEntropyLoss() 与 NLLLoss() 的区别.log似然代价函数 pytorch loss fun ...

  8. word-wrap和word-break的区别吗?

    word-wrap: css的 word-wrap 属性用来标明是否允许浏览器在单词内进行断句,这是为了防止当一个字符串太长而找不到它的自然断句点时产生溢出现象. word-break: css的 w ...

  9. softmax、cross entropy和softmax loss学习笔记

    之前做手写数字识别时,接触到softmax网络,知道其是全连接层,但没有搞清楚它的实现方式,今天学习Alexnet网络,又接触到了softmax,果断仔细研究研究,有了softmax,损失函数自然不可 ...

随机推荐

  1. [POJ 3764] The xor-longest Path

    Description 多组数据 给你一颗树, 然后求一条最长异或路径, 异或路径长度定义为两点间简单路径上所有边权的异或和. Solution 首先 dfs 一遍,求出所有的点到根节点(随便选一个) ...

  2. Quikapp快应用开发入门

    快应诞生背景 微信的小程序使得很多原来需要调动APP的场景不复存在,正式由于微信小程序的冲击,3月20日,华为联手九大手机厂商,共同举办了“快应用”标准启动发布会.“快应用”是几家手机厂商基于硬件平台 ...

  3. Python中四种样式的99乘法表

    1.常规型. #常规型 i=1 while i<=9: j=1 while j<=i: print(''%d*%d=%2d''%(i,j,i*j),end='') i+=1 #等号只是用来 ...

  4. 【Python】 xml解析与生成 xml

    xml *之前用的时候也没想到..其实用BeautifulSoup就可以解析xml啊..因为html只是xml的一种实现方式吧.但是很蛋疼的一点就是,bs不提供获取对象的方法,其find大多获取的都是 ...

  5. [poj2752]Seek the Name, Seek the Fame_KMP

    Seek the Name, Seek the Fame poj-2752 题目大意:给出一个字符串p,求所有既是p的前缀又是p的后缀的所有字符串长度,由小到大输出. 注释:$1\le strlen( ...

  6. 从源码来看ReentrantLock和ReentrantReadWriteLock

    上一篇花了点时间将同步器看了一下,心中对锁的概念更加明确了一点,知道我们所使用到的锁是怎么样获取同步状态的,我们也写了一个自定义同步组件Mutex,讲到了它其实就是一个简版的ReentrantLock ...

  7. 从零部署Spring boot项目到云服务器(准备工作)

    自己的博客终于成功部署上线了,回过头来总结记录一下整个项目的部署过程! 测试地址:47.94.154.205:8084 注:文末有福利! 一.Linux下应用Shell通过SSH连接云服务器 //ss ...

  8. Java基础学习(二)

    软件设计原则: 为了提高软件的开发效率,降低软件开发成本,一个优良的软件系统应该具有以下特点: 1,可重用性:遵循DRY原则,减少软件中的重复代码. 2,可拓展性:当软件需要升级增加新的功能,能够在现 ...

  9. qt中控件的使用函数

    1.Text Edit编辑框 //将编辑框中的内容转化成Utf8编码 ui->textEdit->toPlainText().toUtf8(); 2.Combo Box下拉框的应用 (1) ...

  10. Java之排序

    1.插入排序 假设第一个数已经是排好序的,把第二个根据大小关系插到第一个前面或维持不动,把第三个根据前面两个的大小关系插到对应位置,依次往后. public class InsertSort { pu ...