机器学习的面试题中经常会被问到交叉熵(cross entropy)和最大似然估计(MLE)或者KL散度有什么关系,查了一些资料发现优化这3个东西其实是等价的。

熵和交叉熵

提到交叉熵就需要了解下信息论中熵的定义。信息论认为:

确定的事件没有信息,随机事件包含最多的信息。

事件信息的定义为:\(I(x)=-log(P(x))\);而熵就是描述信息量:\(H(x)=E_{x\sim P}[I(x)]\),也就是\(H(x)=E_{x\sim P}[-log(P(x))]=-\Sigma_xP(x)log(P(x))\)。如果log的base是2,熵可以认为是衡量编码对应的信息需要的最少bits数;那么交叉熵就是来衡量用特定的编码方案Q来对分布为P的信息x进行编码时需要的最少的bits数。定义如下:
\(H(P, Q)=-\Sigma_xP(x)log(Q(x))\)
在深度学习中,P是gt label的真实分布;Q就是网络学习后输出的分布。

最大似然估计

机器学习中,通过最大似然估计方法使参数为\(\hat\Theta\)的模型使预测值贴近真实数据的概率最大化,即\(\hat\Theta=arg\ max_\theta \Pi_{i=1}^Np(x_i|\Theta)\)。实际操作中,连乘很容易出现最大值或最小值溢出,造成计算不稳定,由于log函数的单调性,所以将上式进行取对数取负,最小化负对数似然(NLL)的结果与原始式子是一样的,即\(\hat \Theta =arg\ min_\Theta - \Sigma_{i=1}^Nlog(p(x_i|\Theta))\).

对模型的预测值进行最大似然估计,
\(\hat \Theta =arg\ min_\Theta - \Sigma_{i=1}^Nlog(q(x_i|\Theta))\)
\(=arg\min_\Theta-\Sigma_{x\in X}p(x)log(q(x|\Theta))\)
\(=arg\ min_\Theta H(p, q)\)

所以最小化NLL和最小化交叉熵最后达到的效果是一样的。

KL散度

在深度学习中,KL散度用来评估模型输出的预测值分布与真值分布之间的差异,定义如下:\(D_{KL}(P||Q)=E_xlog(P(x)/Q(x))\)
\(D_{KL}(P||Q)=\Sigma_{x=1}^NP(x)log(P(x)/Q(x))\)
\(=\Sigma_{x=1}^NP(x)[logP(x)-logQ(x)]\)

注意:KL散度不是标准的距离,因为不满足互换性,即\(D_{KL}(P||Q)\neq D_{KL}(Q||P)\)
对于交叉熵:
\(H(P, Q) = -\Sigma PlogQ\)
\(= -\Sigma PlogP+\Sigma PlogP-\Sigma PlogQ\)
\(= H(P) +\Sigma PlogP/Q\)
\(=H(P)+D_{KL}(P||Q)\)

也就是交叉熵就是真值分布的熵与KL散度的和,而真值的熵是确定的,与模型的参数\(\Theta\)无关,所以梯度下降求导时 \(\nabla H(P, Q)=\nabla D_{KL}(P||Q)\),也就是说最小化交叉熵与最小化KL散度是一样的。

总结

从优化模型参数角度来说,最小化交叉熵,NLL,KL散度这3种方式对模型参数的更新来说是一样的。从这点来看也解释了为什么在深度学习中交叉熵是非常常用的损失函数的原因了。

参考:

https://jhui.github.io/2017/01/05/Deep-learning-Information-theory/

深度学习中交叉熵和KL散度和最大似然估计之间的关系的更多相关文章

  1. 信息熵,交叉熵与KL散度

    一.信息熵 若一个离散随机变量 \(X\) 的可能取值为 \(X = \{ x_{1}, x_{2},...,x_{n}\}\),且对应的概率为: \[p(x_{i}) = p(X=x_{i}) \] ...

  2. 【机器学习基础】熵、KL散度、交叉熵

    熵(entropy).KL 散度(Kullback-Leibler (KL) divergence)和交叉熵(cross-entropy)在机器学习的很多地方会用到.比如在决策树模型使用信息增益来选择 ...

  3. 【转载】深度学习中softmax交叉熵损失函数的理解

    深度学习中softmax交叉熵损失函数的理解 2018-08-11 23:49:43 lilong117194 阅读数 5198更多 分类专栏: Deep learning   版权声明:本文为博主原 ...

  4. [ML]熵、KL散度、信息增益、互信息-学习笔记

    [ML]熵.KL散度.信息增益.互信息-学习笔记 https://segmentfault.com/a/1190000000641079

  5. 深度学习中正则化技术概述(附Python代码)

    欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! 磐石 介绍 数据科学研究者们最常遇见的问题之一就是怎样避免过拟合. ...

  6. 卷积在深度学习中的作用(转自http://timdettmers.com/2015/03/26/convolution-deep-learning/)

    卷积可能是现在深入学习中最重要的概念.卷积网络和卷积网络将深度学习推向了几乎所有机器学习任务的最前沿.但是,卷积如此强大呢?它是如何工作的?在这篇博客文章中,我将解释卷积并将其与其他概念联系起来,以帮 ...

  7. DDos攻击,使用深度学习中 栈式自编码的算法

    转自:http://www.airghc.top/2016/11/10/Dection-DDos/ 最近研究了一篇论文,关于检测DDos攻击,使用了深度学习中 栈式自编码的算法,现在简要介绍一下内容论 ...

  8. 从极大似然估计的角度理解深度学习中loss函数

    从极大似然估计的角度理解深度学习中loss函数 为了理解这一概念,首先回顾下最大似然估计的概念: 最大似然估计常用于利用已知的样本结果,反推最有可能导致这一结果产生的参数值,往往模型结果已经确定,用于 ...

  9. 深度学习中的Data Augmentation方法(转)基于keras

    在深度学习中,当数据量不够大时候,常常采用下面4中方法: 1. 人工增加训练集的大小. 通过平移, 翻转, 加噪声等方法从已有数据中创造出一批"新"的数据.也就是Data Augm ...

随机推荐

  1. 学习Struts--Chap05:值栈和OGNL

    1.值栈的介绍 1.1 值栈的介绍: 值栈是对应每一个请求对象的数据存储中心,struts2会给每一个请求对象创建一个值栈,我们大多数情况下不需要考虑值栈在哪里,里面有什么,只需要去获取自己需要的数据 ...

  2. jmeter接口测试实例6-注册(参数化)

    Jmeter实例6:注册(参数化) 选中http协议,添加CSV Data set Config 准备参数中要使用到的值,存放到txt中,如果一个里面有多个参数,中间用,号分隔: 选中CSV元件,fi ...

  3. gravity 和 layout_gravity

    gravity : 是控件内部的内容的对齐方式. layout_gravity: 是控件相对于其容器的对齐方式. 如果 LinearLayout  的  android:orientation=&qu ...

  4. 正则表达式零宽断言详解(?=,?<=,?!,?<!)

    在使用正则表达式时,有时我们需要捕获的内容前后必须是特定内容,但又不捕获这些特定内容的时候,零宽断言就起到作用了 正则表达式零宽断言: 零宽断言是正则表达式中的难点,所以重点从匹配原理方面进行分析.零 ...

  5. hihocoder1524

    题目链接 时间限制:10000ms 单点时限:1000ms 内存限制:256MB 描述 给定一个1-N的排列A1, A2, ... AN,如果Ai和Aj满足i < j且Ai > Aj,我们 ...

  6. python之装饰器篇

    一.基本装饰器 基本装饰器的作用: 在不改变原函数的基础上, 通过装饰器, 给原函数新增某些功能 实现方法: 在原函数上加 @装饰器名字 其中@叫做语法糖 定义装饰器 第一层函数传入参数(用于传入原函 ...

  7. eclipse实现代码块折叠-类似于VS中的#region……#endregion

    背 景 刚才在写代码的时候,写了十几行可以说是重复的代码: 如果整个方法或类中代码多了,感觉它们太TM占地方了,给读者在阅读代码上造成很大的困难,于是想到能不能把他们“浓缩”成一行,脑子里第一个闪现出 ...

  8. double compare 0

    因为double类型或float类型都是有精度的,其实都是取的近似值,所以有个误差.和一个很小的数比如0.00000001(1e-8)比较就是为了在这个误差范围内进行比较. 举个例子如double b ...

  9. Linux DNS 查询剖析(第四部分) | Linux 中国

    版权声明:本文为博主原创文章,未经博主同意不得转载. https://blog.csdn.net/F8qG7f9YD02Pe/article/details/82879414 在第四部分中,我将介绍容 ...

  10. Redis接口的调用

    1.hiredis是redis数据库的C接口,目录为/redis-3.2.6/deps/hiredis 2.示例代码如下: #include <stdio.h> #include < ...