系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI
点击star加星不要吝啬,星越多笔者越努力。

3.1 均方差函数

MSE - Mean Square Error。

该函数就是最直观的一个损失函数了,计算预测值和真实值之间的欧式距离。预测值和真实值越接近,两者的均方差就越小。

均方差函数常用于线性回归(linear regression),即函数拟合(function fitting)。公式如下:

\[
loss = {1 \over 2}(z-y)^2 \tag{单样本}
\]

\[
J=\frac{1}{2m} \sum_{i=1}^m (z_i-y_i)^2 \tag{多样本}
\]

3.1.1 工作原理

要想得到预测值a与真实值y的差距,最朴素的想法就是用\(Error=a_i-y_i\)。

对于单个样本来说,这样做没问题,但是多个样本累计时,\(a_i-y_i\)有可能有正有负,误差求和时就会导致相互抵消,从而失去价值。所以有了绝对值差的想法,即\(Error=|a_i-y_i|\)。这看上去很简单,并且也很理想,那为什么还要引入均方差损失函数呢?两种损失函数的比较如表3-1所示。

表3-1 绝对值损失函数与均方差损失函数的比较

样本标签值 样本预测值 绝对值损失函数 均方差损失函数
\([1,1,1]\) \([1,2,3]\) \((1-1)+(2-1)+(3-1)=3\) \((1-1)^2+(2-1)^2+(3-1)^2=5\)
\([1,1,1]\) \([1,3,3]\) \((1-1)+(3-1)+(3-1)=4\) \((1-1)^2+(3-1)^2+(3-1)^2=8\)
\(4/3=1.33\) \(8/5=1.6\)

可以看到5比3已经大了很多,8比4大了一倍,而8比5也放大了某个样本的局部损失对全局带来的影响,用术语说,就是“对某些偏离大的样本比较敏感”,从而引起监督训练过程的足够重视,以便回传误差。

3.1.2 实际案例

假设有一组数据如图3-3,我们想找到一条拟合的直线。

图3-3 平面上的样本数据

图3-4中,前三张显示了一个逐渐找到最佳拟合直线的过程。

  • 第一张,用均方差函数计算得到Loss=0.53;
  • 第二张,直线向上平移一些,误差计算Loss=0.16,比图一的误差小很多;
  • 第三张,又向上平移了一些,误差计算Loss=0.048,此后还可以继续尝试平移(改变b值)或者变换角度(改变w值),得到更小的损失函数值;
  • 第四张,偏离了最佳位置,误差值Loss=0.18,这种情况,算法会让尝试方向反向向下。

图3-4 损失函数值与直线位置的关系

第三张图损失函数值最小的情况。比较第二张和第四张图,由于均方差的损失函数值都是正值,如何判断是向上移动还是向下移动呢?

在实际的训练过程中,是没有必要计算损失函数值的,因为损失函数值会体现在反向传播的过程中。我们来看看均方差函数的导数:

\[
\frac{\partial{J}}{\partial{a_i}} = a_i-y_i
\]

虽然\((a_i-y_i)^2\)永远是正数,但是\(a_i-y_i\)却可以是正数(直线在点下方时)或者负数(直线在点上方时),这个正数或者负数被反向传播回到前面的计算过程中,就会引导训练过程朝正确的方向尝试。

在上面的例子中,我们有两个变量,一个w,一个b,这两个值的变化都会影响最终的损失函数值的。

我们假设该拟合直线的方程是y=2x+3,当我们固定w=2,把b值从2到4变化时,看看损失函数值的变化如图3-5所示。

图3-5 固定W时,b的变化造成的损失值

我们假设该拟合直线的方程是y=2x+3,当我们固定b=3,把w值从1到3变化时,看看损失函数值的变化如图3-6所示。

图3-6 固定b时,W的变化造成的损失值

3.1.3 损失函数的可视化

损失函数值的3D示意图

横坐标为W,纵坐标为b,针对每一个w和一个b的组合计算出一个损失函数值,用三维图的高度来表示这个损失函数值。下图中的底部并非一个平面,而是一个有些下凹的曲面,只不过曲率较小,如图3-7。

图3-7 W和b同时变化时的损失值形成的曲面

损失函数值的2D示意图

在平面地图中,我们经常会看到用等高线的方式来表示海拔高度值,下图就是上图在平面上的投影,即损失函数值的等高线图,如图3-8所示。

图3-8 损失函数的等高线图

如果还不能理解的话,我们用最笨的方法来画一张图,代码如下:

    s = 200
    W = np.linspace(w-2,w+2,s)
    B = np.linspace(b-2,b+2,s)
    LOSS = np.zeros((s,s))
    for i in range(len(W)):
        for j in range(len(B)):
            z = W[i] * x + B[j]
            loss = CostFunction(x,y,z,m)
            LOSS[i,j] = round(loss, 2)

上述代码针对每个w和b的组合计算出了一个损失值,保留小数点后2位,放在LOSS矩阵中,如下所示:

[[4.69 4.63 4.57 ... 0.72 0.74 0.76]
 [4.66 4.6  4.54 ... 0.73 0.75 0.77]
 [4.62 4.56 4.5  ... 0.73 0.75 0.77]
 ...
 [0.7  0.68 0.66 ... 4.57 4.63 4.69]
 [0.69 0.67 0.65 ... 4.6  4.66 4.72]
 [0.68 0.66 0.64 ... 4.63 4.69 4.75]]

然后遍历矩阵中的损失函数值,在具有相同值的位置上绘制相同颜色的点,比如,把所有值为0.72的点绘制成红色,把所有值为0.75的点绘制成蓝色......,这样就可以得到图3-9。

图3-9 用笨办法绘制等高线图

此图和等高线图的表达方式等价,但由于等高线图比较简明清晰,所以以后我们都使用等高线图来说明问题。

代码位置

ch03, Level1

[ch03-01] 均方差损失函数的更多相关文章

  1. TensorFlow2.0(8):误差计算——损失函数总结

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  2. [ch03-00] 损失函数

    系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI, 点击star加星不要吝啬,星越多笔者越努力. 第3章 损失函数 3.0 损失函数概论 3.0.1 概念 ...

  3. 逻辑回归损失函数(cost function)

    逻辑回归模型预估的是样本属于某个分类的概率,其损失函数(Cost Function)可以像线型回归那样,以均方差来表示:也可以用对数.概率等方法.损失函数本质上是衡量”模型预估值“到“实际值”的距离, ...

  4. 深度神经网络(DNN)损失函数和激活函数的选择

    在深度神经网络(DNN)反向传播算法(BP)中,我们对DNN的前向反向传播算法的使用做了总结.里面使用的损失函数是均方差,而激活函数是Sigmoid.实际上DNN可以使用的损失函数和激活函数不少.这些 ...

  5. [ch03-02] 交叉熵损失函数

    系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI, 点击star加星不要吝啬,星越多笔者越努力. 3.2 交叉熵损失函数 交叉熵(Cross Entrop ...

  6. 反向传播算法-损失函数&激活函数

    在监督学习中,传统的机器学习算法优化过程是采用一个合适的损失函数度量训练样本输出损失,对损失函数进行优化求最小化的极值,相应一系列线性系数矩阵W,偏置向量b即为我们的最终结果.在DNN中,损失函数优化 ...

  7. 强化学习(十六) 深度确定性策略梯度(DDPG)

    在强化学习(十五) A3C中,我们讨论了使用多线程的方法来解决Actor-Critic难收敛的问题,今天我们不使用多线程,而是使用和DDQN类似的方法:即经验回放和双网络的方法来改进Actor-Cri ...

  8. 使用CNN做数字识别和人脸识别

    上次写的一层神经网络也都贴这里了. 我有点困,我先睡觉,完了我再修改 这个代码写法不太符合工业代码的规范,仅仅是用来学习的的.还望各位见谅 import sys,ossys.path.append(o ...

  9. Tensorflow之多元线性回归问题(以波士顿房价预测为例)

    一.根据波士顿房价信息进行预测,多元线性回归+特征数据归一化 #读取数据 %matplotlib notebook import tensorflow as tf import matplotlib. ...

随机推荐

  1. vue 父子组件通信详解

    这是一篇详细讲解vue父子组件之间通信的文章,初始学习vue的时候,总是搞不清楚几个情况 通过props在父子组件传值时,v-bind:data="data",props接收的到底 ...

  2. httprunner-2-linux下搭建hrun(下)

    前言 前面我们说了linux下安装python3,hrun是需要依赖数据库,我们用docker进行安装mysql5.7让数据库能正常连接.安装mysql5.7请参考:https://www.cnblo ...

  3. malloc面试题目(转) - [C++]

    试题4: void GetMemory( char *p ){ p = (char *) malloc( 100 );} void Test( void ) { char *str = NULL; G ...

  4. DRF之注册器、响应器、分页器

    一.url注册器 通过DRF的视图组件,数据接口逻辑被我们优化到最剩下一个类,接下来,我们使用DRF的url控制器来帮助我们自动生成url,使用步骤如下: 第一步:导入模块 1 from rest_f ...

  5. 路由器配置深入浅出—路由器接口PPP协议封装及PAP和CHAP验证配置

    知识域: 是针对点对点专线连接的接口的二层封装协议配置 PPP的PAP和CHAP验证,cpt支持,不一定要在gns3上做实验. 路由器出厂默认是hdlc封装,修改为ppp封装后,可以采用pap验证或者 ...

  6. List、Set集合系列之剖析HashSet存储原理(HashMap底层)

    目录 List接口 1.1 List接口介绍 1.2 List接口中常用方法 List的子类 2.1 ArrayList集合 2.2 LinkedList集合 Set接口 3.1 Set接口介绍 Se ...

  7. java 项目时间和服务器时间不一致

    今天线上项目关于时间的几个任务都出了问题,查看日志发现日志的时间不对,用的是log4j,日志输出的时间都早了很长时间. 1 首先先登上服务器查看了服务器的系统时间 linux下 date命令 时间正确 ...

  8. 第三十一章 System V信号量(二)

    用信号量实现进程互斥示例 #include <unistd.h> #include <sys/types.h> #include <stdlib.h> #inclu ...

  9. Web for pentester_writeup之Code injection篇

    Web for pentester_writeup之Code injection篇 Code injection(代码注入) Example 1 <1> name=hacker' 添加一个 ...

  10. 【java基础】为什么重写toString()方法?

    不得不说,有很多java初学者写java实体类的时候,并没有真正理解重写toString() 方法,可能是口头知道也可能是跟风随带添加toString() 方法,并没有真正理解其意义,如果真要被问起来 ...