【转】TensorFlow四种Cross Entropy算法实现和应用
http://www.jianshu.com/p/75f7e60dae95
作者:陈迪豪 来源:CSDN
http://dataunion.org/26447.html
交叉熵介绍
交叉熵(Cross Entropy)是Loss函数的一种(也称为损失函数或代价函数),用于描述模型预测值与真实值的差距大小,常见的Loss函数就是均方平方差(Mean Squared Error),定义如下。
平方差很好理解,预测值与真实值直接相减,为了避免得到负数取绝对值或者平方,再做平均就是均方平方差。注意这里预测值需要经过sigmoid激活函数,得到取值范围在0到1之间的预测值。
平方差可以表达预测值与真实值的差异,但在分类问题种效果并不如交叉熵好,原因可以参考这篇博文 。交叉熵的定义如下,截图来自https://hit-scir.gitbooks.io/neural-networks-and-deep-learning-zh_cn/content/chap3/c3s1.html 。
上面的文章也介绍了交叉熵可以作为Loss函数的原因,首先是交叉熵得到的值一定是正数,其次是预测结果越准确值越小,注意这里用于计算的“a”也是经过sigmoid激活的,取值范围在0到1。如果label是1,预测值也是1的话,前面一项y ln(a)就是1 ln(1)等于0,后一项(1 – y) ln(1 – a)也就是0 ln(0)等于0,Loss函数为0,反之Loss函数为无限大非常符合我们对Loss函数的定义。
这里多次强调sigmoid激活函数,是因为在多目标或者多分类的问题下有些函数是不可用的,而TensorFlow本身也提供了多种交叉熵算法的实现。
TensorFlow的交叉熵函数
TensorFlow针对分类问题,实现了四个交叉熵函数,分别是
tf.nn.sigmoid_cross_entropy_with_logits、tf.nn.softmax_cross_entropy_with_logits、tf.nn.sparse_softmax_cross_entropy_with_logits和tf.nn.weighted_cross_entropy_with_logits,详细内容参考API文档https://www.tensorflow.org/versions/master/api_docs/python/nn.html#sparse_softmax_cross_entropy_with_logits
sigmoid_cross_entropy_with_logits详解
我们先看sigmoid_cross_entropy_with_logits,为什么呢,因为它的实现和前面的交叉熵算法定义是一样的,也是TensorFlow最早实现的交叉熵算法。这个函数的输入是logits和targets,logits就是神经网络模型中的 W X矩阵,注意不需要经过sigmoid,而targets的shape和logits相同,就是正确的label值,例如这个模型一次要判断100张图是否包含10种动物,这两个输入的shape都是[100, 10]。注释中还提到这10个分类之间是独立的、不要求是互斥,这种问题我们成为多目标,例如判断图片中是否包含10种动物,label值可以包含多个1或0个1,还有一种问题是多分类问题,例如我们对年龄特征分为5段,只允许5个值有且只有1个值为1,这种问题可以直接用这个函数吗?答案是不可以,我们先来看看sigmoid_cross_entropy_with_logits的代码实现吧。
可以看到这就是标准的Cross Entropy算法实现,对W X得到的值进行sigmoid激活,保证取值在0到1之间,然后放在交叉熵的函数中计算Loss。对于二分类问题这样做没问题,但对于前面提到的多分类,例如年轻取值范围在0~4,目标值也在0~4,这里如果经过sigmoid后预测值就限制在0到1之间,而且公式中的1 – z就会出现负数,仔细想一下0到4之间还不存在线性关系,如果直接把label值带入计算肯定会有非常大的误差。因此对于多分类问题是不能直接代入的,那其实我们可以灵活变通,把5个年龄段的预测用onehot encoding变成5维的label,训练时当做5个不同的目标来训练即可,但不保证只有一个为1,对于这类问题TensorFlow又提供了基于Softmax的交叉熵函数。
softmax_cross_entropy_with_logits详解
Softmax本身的算法很简单,就是把所有值用e的n次方计算出来,求和后算每个值占的比率,保证总和为1,一般我们可以认为Softmax出来的就是confidence也就是概率,算法实现如下。
softmax_cross_entropy_with_logits和sigmoid_cross_entropy_with_logits很不一样,输入是类似的logits和lables的shape一样,但这里要求分类的结果是互斥的,保证只有一个字段有值,例如CIFAR-10中图片只能分一类而不像前面判断是否包含多类动物。想一下问什么会有这样的限制?在函数头的注释中我们看到,这个函数传入的logits是unscaled的,既不做sigmoid也不做softmax,因为函数实现会在内部更高效得使用softmax,对于任意的输入经过softmax都会变成和为1的概率预测值,这个值就可以代入变形的Cross Entroy算法- y ln(a) – (1 – y) ln(1 – a)算法中,得到有意义的Loss值了。如果是多目标问题,经过softmax就不会得到多个和为1的概率,而且label有多个1也无法计算交叉熵,因此这个函数只适合单目标的二分类或者多分类问题,TensorFlow函数定义如下。
再补充一点,对于多分类问题,例如我们的年龄分为5类,并且人工编码为0、1、2、3、4,因为输出值是5维的特征,因此我们需要人工做onehot encoding分别编码为00001、00010、00100、01000、10000,才可以作为这个函数的输入。理论上我们不做onehot encoding也可以,做成和为1的概率分布也可以,但需要保证是和为1,和不为1的实际含义不明确,TensorFlow的C++代码实现计划检查这些参数,可以提前提醒用户避免误用。
sparse_softmax_cross_entropy_with_logits详解
sparse_softmax_cross_entropy_with_logits是softmax_cross_entropy_with_logits的易用版本,除了输入参数不同,作用和算法实现都是一样的。前面提到softmax_cross_entropy_with_logits的输入必须是类似onehot encoding的多维特征,但CIFAR-10、ImageNet和大部分分类场景都只有一个分类目标,label值都是从0编码的整数,每次转成onehot encoding比较麻烦,有没有更好的方法呢?答案就是用sparse_softmax_cross_entropy_with_logits,它的第一个参数logits和前面一样,shape是[batch_size, num_classes],而第二个参数labels以前也必须是[batch_size, num_classes]否则无法做Cross Entropy,这个函数改为限制更强的[batch_size],而值必须是从0开始编码的int32或int64,而且值范围是[0, num_class),如果我们从1开始编码或者步长大于1,会导致某些label值超过这个范围,代码会直接报错退出。这也很好理解,TensorFlow通过这样的限制才能知道用户传入的3、6或者9对应是哪个class,最后可以在内部高效实现类似的onehot encoding,这只是简化用户的输入而已,如果用户已经做了onehot encoding那可以直接使用不带“sparse”的softmax_cross_entropy_with_logits函数。
weighted_sigmoid_cross_entropy_with_logits详解
weighted_sigmoid_cross_entropy_with_logits是sigmoid_cross_entropy_with_logits的拓展版,输入参数和实现和后者差不多,可以多支持一个pos_weight参数,目的是可以增加或者减小正样本在算Cross Entropy时的Loss。实现原理很简单,在传统基于sigmoid的交叉熵算法上,正样本算出的值乘以某个系数接口,算法实现如下。
总结
这就是TensorFlow目前提供的有关Cross Entropy的函数实现,用户需要理解多目标和多分类的场景,根据业务需求(分类目标是否独立和互斥)来选择基于sigmoid或者softmax的实现,如果使用sigmoid目前还支持加权的实现,如果使用softmax我们可以自己做onehot coding或者使用更易用的sparse_softmax_cross_entropy_with_logits函数。
TensorFlow提供的Cross Entropy函数基本cover了多目标和多分类的问题,但如果同时是多目标多分类的场景,肯定是无法使用softmax_cross_entropy_with_logits,如果使用sigmoid_cross_entropy_with_logits我们就把多分类的特征都认为是独立的特征,而实际上他们有且只有一个为1的非独立特征,计算Loss时不如Softmax有效。这里可以预测下,未来TensorFlow社区将会实现更多的op解决类似的问题,我们也期待更多人参与TensorFlow贡献算法和代码 !
作者:贰拾贰画生
链接:http://www.jianshu.com/p/75f7e60dae95
來源:简书
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
【转】TensorFlow四种Cross Entropy算法实现和应用的更多相关文章
- php四种基础排序算法的运行时间比较
/** * php四种基础排序算法的运行时间比较 * @authors Jesse (jesse152@163.com) * @date 2016-08-11 07:12:14 */ //冒泡排序法 ...
- PHP四种基本排序算法
PHP的四种基本排序算法为:冒泡排序.插入排序.选择排序和快速排序. 下面是我整理出来的算法代码: 1. 冒泡排序: 思路:对数组进行多轮冒泡,每一轮对数组中的元素两两比较,调整位置,冒出一个最大的数 ...
- php四种基础排序算法的运行时间比较!
/** * php四种基础排序算法的运行时间比较 * @authors Jesse (jesse152@163.com) * @date 2016-08-11 07:12:14 */ //冒泡排序法 ...
- TCP快速重传与快速恢复原理分析(四种不同的算法)
在TCP/IP中,快速重传和恢复(fast retransmit and recovery,FRR)是一种拥塞控制算法,它能快速恢复丢失的数据包.没有FRR,如果数据包丢失了,TCP将会使用定时器来要 ...
- JVM03——四种垃圾回收算法,你都了解了哪几种
在之前的文章中,已经为各位带来了JVM的内存结构与堆内存的相关介绍,今天将为为各位详解JVM垃圾回收与算法.关注我的公众号「Java面典」了解更多 Java 相关知识点. 如何确定垃圾 想要回收垃圾, ...
- PHP实现四种基本排序算法
前提:分别用冒泡排序法,快速排序法,选择排序法,插入排序法将下面数组中的值按照从小到大的顺序进行排序. $arr(1,43,54,62,21,66,32,78,36,76,39); 1. 冒泡排序 思 ...
- PHP 四种基本排序算法的代码实现
前提:分别用冒泡排序法,快速排序法,选择排序法,插入排序法将下面数组中的值按照从小到大的顺序进行排序. $arr(1,43,54,62,21,66,32,78,36,76,39); 1. 冒泡排序 思 ...
- PHP——四种基本排序算法
分别用冒泡排序法,快速排序法,选择排序法,插入排序法将下面数组中的值按照从小到大的顺序进行排序. $arr(1,43,54,62,21,66,32,78,36,76,39); 1. 冒泡排序 思路分析 ...
- php 四种基础的算法 ---- 冒泡排序法
1. 冒泡排序法 * 思路分析:法如其名,就是像冒泡一样,每次从数组当中 冒一个最大的数出来. * 比如:2,4,1 // 第一次 冒出的泡是4 * ...
随机推荐
- 在windows下安装配置python开发环境及Ulipad开发工具(转)
最近开始学习Python,在网上寻找一下比较好的IDE.因为以前用C#做开发的,用Visual Studio作为IDE,鉴于用惯了VS这么强大的IDE,所以对IDE有一定的依赖性. Python的ID ...
- STM32 F4 DAC DMA Waveform Generator
STM32 F4 DAC DMA Waveform Generator Goal: generating an arbitrary periodic waveform using a DAC with ...
- [Go] 单元测试/性能测试 (go test)
特征 Golang 单元测试对文件名和方法名,参数都有很严格的要求.例如: 1.文件名必须以 _test.go 结尾 2.方法名必须是 Test 开头 3.方法参数必须是 t *testing.T 或 ...
- 在树莓派2上安装 Windows 10
微软在2015年4月29日发布了树莓派玩家期待已久的 Windows 10 物联网核心预览版(Windows 10 IoT Core Insider Preview Image for Raspber ...
- EBS查询默认应用用户,比如是否需要锁定、修改这些用户
/* Formatted on 2018/3/15 13:05:39 (QP5 v5.256.13226.35538) */ --查询默认应用用户,比如是否需要锁定.修改这些用户 SELECT ROW ...
- 平台升级至spring 4.3.0 运行稳定
deprecated velocity 部份方法 整理中...
- 《Android编程权威指南》
<Android编程权威指南> 基本信息 原书名:Android programming: the big nerd ranch guide 原出版社: Big Nerd Ranch Gu ...
- Java生成8位随机邀请码,不重复
public static String[] chars = new String[] { "a", "b", "c", "d&q ...
- Orchard模块开发全接触4:深度改造前台
这里,我们需要做一些事情,这些事情意味着深度改造前台: 1:为商品增加 添加到购物车 按钮,点击后功能实现: 2:商品排序: 3:购物车预览,以及添加 结算 按钮: 4:一个显式 购物车中有*个 商品 ...
- main函数的参数argc和argv
版权声明:本文为博主原创文章,转载请注明CSDN博客源地址!共同学习,一起进步~ https://blog.csdn.net/Eastmount/article/details/20413773 该篇 ...