注:
文章中所有的图片均来自台湾大学林轩田《机器学习基石》课程。
笔记原作者:红色石头
微信公众号:AI有道

上一节课,我们介绍了Logistic Regression问题,建立cross-entropy error,并提出使用梯度下降算法gradient descent来获得最好的logistic hypothesis。本节课继续介绍使用线性模型来解决分类问题。

一、Linear Models for Binary Classification

之前介绍的几种线性模型都有一个共同点,就是都有样本特征\(x\)的加权运算,我们引入一个线性得分函数\(s\):\[s=w^Tx\]三种线性模型,第一种是linear classification。线性分类模型的hypothesis为\(h(x)=sign(s)\),取值范围为\({-1,+1}\)两个值,它的err是\(0/1\)的,所以对应的\(E_{in}(w)\)是离散的,并不好解,这是个NP-hard问题。第二种是linear regression。线性回归模型的hypothesis为\(h(x)=s\),取值范围为整个实数空间,它的err是squared的,所以对应的\(E_{in}(w)\)是开口向上的二次曲线,其解是closed-form的,直接用线性最小二乘法求解即可。第三种是logistic regression。逻辑回归模型的hypothesis为\(h(x)=\theta(s)\),取值范围为\((0,1)\)之间,它的err是cross-entropy的,对应的\(E_{in}(w)\)是平滑的凸函数,可以使用梯度下降算法求最小值。

从上图中,我们发现,linear regression和logistic regression的error function都有最小解。那么可不可以用这两种方法来求解linear classification问题呢?下面,我们来对这三种模型的error function进行分析,看看它们之间有什么联系。
对于linear classification,它的error function可以写成:\[err_{0/1}(s,y)=[ sign(s)\neq y] =[ sign(ys)\neq 1]\]
对于linear regression,它的error function可以写成:\[err_{SQR}(s,y)=(s-y)^2=(ys-1)^2\]
对于logistic regression,它的error function可以写成:\[err_{CE}(s,y)=ln(1+exp(-ys))\]
上述三种模型的error function都引入了\(ys\)变量,那么\(ys\)的物理意义是什么?\(ys\)就是指分类的正确率得分,其值越大越好,得分越高。

下面,用图形化的方式来解释三种模型的error function到底有什么关系:

从上图中可以看出,\(ys\)是横坐标轴, \(err_{0/1}\)是呈阶梯状的,在\(ys>0\)时, 恒取最小值\(0\)。\(err_{SQR}\)呈抛物线形式,在\(ys=1\)时,取得最小值,且在\(ys=1\)左右很小区域内,\(err_{0/1}\)和\(err_{SQR}\)近似。\(err_{CE}\)是呈指数下降的单调函数,\(ys\)越大,其值越小。同样在\(ys=1\)左右很小区域内, \(err_{0/1}\)和\(err_{CE}\)近似。但是我们发现\(err_{CE}\)并不是始终在\(err_{0/1}\)之上,所以为了计算讨论方便,我们把\(err_{CE}\)做幅值上的调整,引入\(err_{SCE}=log_2(1+exp(-ys))=\frac{1}{ln2}err_{CE}\),这样能保证\(err_{SCE}\)始终在\(err_{0/1}\)上面,如下图所示:

由上图可以看出:\[err_{0/1}(s,y)\leq err_{SCE}(s,y)=\frac{1}{ln2}err_{CE}(s,y)\] \[E^{0/1}_{in}(w)\leq E^{SCE}_{in}(w)=\frac{1}{ln2}E^{CE}_{in}(w)\] \[E^{0/1}_{out}(w)\leq E^{SCE}_{out}(w)=\frac{1}{ln2}E^{CE}_{out}(w)\]
那么由VC理论可以知道:
从\(0/1\)出发:\[E^{0/1}_{out}(w)\leq E^{0/1}_{in}(w)+\Omega^{0/1}\leq \frac{1}{ln2}E^{CE}_{in}(w)+\Omega^{0/1}\]
从CE出发:\[E^{0/1}_{out}(w)\leq \frac{1}{ln2}E^{CE}_{out}(w)\leq \frac{1}{ln2}E^{CE}_{in}(w)+\frac{1}{ln2} \Omega^{CE}\]

通过上面的分析,我们看到err 0/1是被限定在一个上界中。这个上界是由logistic regression模型的error function决定的。而linear regression其实也是linear classification的一个upper bound,只是随着\(ys\)偏离1的位置越来越远,linear regression的error function偏差越来越大。综上所述,linear regression和logistic regression都可以用来解决linear classification的问题。

下图列举了PLA、linear regression、logistic regression模型用来解linear classification问题的优点和缺点。通常,我们使用linear regression来获得初始化的\(w_0\),再用logistic regression模型进行最优化解。

二、Stochastic Gradient Descent

之前介绍的PLA算法和logistic regression算法,都是用到了迭代操作。PLA每次迭代只会更新一个点,它每次迭代的时间复杂度是\(O(1)\);而logistic regression每次迭代要对所有\(N\)个点都进行计算,它每次迭代的时间复杂度是\(O(N)\)。为了提高logistic regression中
gradient descent算法的速度,可以使用另一种算法:随机梯度下降算法(StochasticGradient Descent)。

随机梯度下降算法每次迭代只找到一个点,计算该点的梯度,作为我们下一步更新\(w\)的依据。这样就保证了每次迭代的计算量大大减小,我们可以把整体的梯度看成这个随机过程的一个期望值。

随机梯度下降可以看成是真实的梯度加上均值为零的随机噪声方向。单次迭代看,好像会对每一步找到正确梯度方向有影响,但是整体期望值上看,与真实梯度的方向没有差太多,同样能找到最小值位置。随机梯度下降的优点是减少计算量,提高运算速度,而且便于online学习;缺点是不够稳定,每次迭代并不能保证按照正确的方向前进,而且达到最小值需要迭代的次数比梯度下降算法一般要多。

对于logistic regression的SGD,它的表达式为:\[w_{t+1} = w_t + \eta \theta(-y_nw^T_tx_n)(y_nx_n)\]
我们发现,SGD与PLA的迭代公式有类似的地方,如下图所示:

把SDG logistic regression 称之为'soft' PLA, 因为PLA只对分类错误的点进行修正,而SGD logistic regression每次迭代都会进行或多或少的修正。
另外,当\(\eta=1\),且\(w^Tx_n\)足够大的时候,PLA近似等于SGD。(\(y_n=+1\),\(w^Tx_n\)足够大时,PLA: \(w_{t+1}=w_t+0*y_nx_n\), SDG: \(w_{t+1} = w_t + \eta * 0 *(y_nx_n)\) ; \(y_n=-1\),\(w^Tx_n\)足够大时,PLA: \(w_{t+1}=w_t+1*y_nx_n\), SDG: \(w_{t+1} = w_t + \eta * 1 *(y_nx_n)\))。

除此之外,还有两点需要说明:1、SGD的终止迭代条件。没有统一的终止条件,一般让迭代次数足够多;2、学习速率\(\eta\)。\(\eta\)的取值是根据实际情况来定的,一般取值\(0.1\)就可以了。

三、Multiclass via Logistic Regression

之前一直讲的都是二分类问题,本节主要介绍多分类问题,通过linear classification来解决。假设平面上有四个类,分别是正方形、菱形、三角形和星形,如何进行分类模型的训练呢?
首先我们可以想到这样一个办法,就是先把正方形作为正类,其他三种形状都是负类,即把它当成一个二分类问题,通过linear classification模型进行训练,得出平面上某个图形是不是正方形,且只有\({-1,+1}\)两种情况。然后再分别以菱形、三角形、星形为正类,进行二元分类。这样进行四次二分类之后,就完成了这个多分类问题。

但是,这样的二分类会带来一些问题,因为我们只用\({-1,+1}\)两个值来标记,那么平面上可能某些区域都被上述四次二分类模型判断为负类,即不属于四类中的任何一类;也可能会出现某些区域同时被两个类甚至多个类同时判断为正类,比如某个区域又被判定为正方形又被判定为菱形。那么对于这种情况,我们就无法进行多类别的准确判断,所以对于多类别,简单的binary classification不能解决问题。

针对这种问题,我们可以使用另外一种方法来解决:soft软性分类,即不用\({-1,+1}\)这种binary classification,而是使用logistic regression,计算某点属于某类的概率、可能性,取概率最大的值为那一类就好。

soft classification的处理过程和之前类似,同样是分别令某类为正,其他三类为负,不同的是得到的是概率值,而不是\({-1,1}\)。最后得到某点分别属于四类的概率,取最大概率对应的哪一个类别就好。效果如下图所示:

这种多分类的处理方式,我们称之为One-Versus-All(OVA) Decomposition。这种方法的优点是简单高效,可以使用logistic regression模型来解决;缺点是如果数据类别很多时,那么每次二分类问题中,正类和负类的数量差别就很大,数据不平衡unbalanced,这样会影响分类效果。但是,OVA还是非常常用的一种多分类算法。

四、Multiclass via Binary Classification

上一节介绍了多分类算法OVA,但是这种方法存在一个问题,就是当类别\(k\)很多的时候,造成正负类数据unbalanced,会影响分类效果,表现不好。现在,我们介绍另一种方法来解决当\(k\)很大时,OVA带来的问题。

这种方法每次只取两类进行binary classification,取值为\({-1,+1}\)。假如\(k=4\),那么总共需要进行\(C^2_4=6\)次binary classification。那么,六次分类之后,如果平面有个点,有三个分类器判断它是正方形,一个分类器判断是菱形,另外两个判断是三角形,那么取最多的那个,即判断它属于正方形,我们的分类就完成了。这种形式就如同\(k\)个足球对进行单循环的比赛,每场比赛都有一个队赢,一个队输,赢了得1分,输了得0分。那么总共进行了\(C^2_k\)次的比赛,最终取得分最高的那个队就可以了。

这种区别于OVA的多分类方法叫做One-Versus-One(OVO)。这种方法的优点是更加高效,因为虽然需要进行的分类次数增加了,但是每次只需要进行两个类别的比较,也就是说单次分类的数量减少了。而且一般不会出现数据unbalanced的情况。缺点是需要分类的次数多,时间复杂度和空间复杂度可能都比较高。

五、总结

本节课主要介绍了分类问题的三种线性模型:linear classification、linear regression和logistic regression。首先介绍了这三种linear models都可以来做binary classification。然后介绍了比梯度下降算法更加高效的SGD算法来进行logistic regression分析。最后讲解了两种多分类方法,一种是OVA,另一种是OVO。这两种方法各有优缺点,当类别数量\(k\)不多的时候,建议选择OVA,以减少分类次数。

机器学习基石11-Linear Models for Classification的更多相关文章

  1. 《机器学习基石》---Linear Models for Classification

    1 用回归来做分类 到目前为止,我们学习了线性分类,线性回归,逻辑回归这三种模型.以下是它们的pointwise损失函数对比(为了更容易对比,都把它们写作s和y的函数,s是wTx,表示线性打分的分数) ...

  2. 机器学习基石笔记:11 Linear Models for Classification

    一.二元分类的线性模型 线性分类.线性回归.逻辑回归: 可视化这三个线性模型的代价函数, SQR.SCE的值都是大于等于0/1的. 理论分析上界: 将回归应用于分类: 线性回归后的参数值常用于pla/ ...

  3. 机器学习基石笔记:11 Linear Models for Classification、LC vs LinReg vs LogReg、OVA、OVO

    原文地址:https://www.jianshu.com/p/6f86290e70f9 一.二元分类的线性模型 线性回归后的参数值常用于PLA/PA/Logistic Regression的参数初始化 ...

  4. 11 Linear Models for Classification

    一.二元分类的线性模型 线性分类.线性回归.逻辑回归 可视化这三个线性模型的代价函数 SQR.SCE的值都是大于等于0/1的 理论分析上界 将回归应用于分类 线性回归后的参数值常用于pla/pa/lo ...

  5. Coursera台大机器学习课程笔记10 -- Linear Models for Classification

    这一节讲线性模型,先将几种线性模型进行了对比,通过转换误差函数来将linear regression 和logistic regression 用于分类. 比较重要的是这种图,它解释了为何可以用Lin ...

  6. PRML读书会第四章 Linear Models for Classification(贝叶斯marginalization、Fisher线性判别、感知机、概率生成和判别模型、逻辑回归)

    主讲人 planktonli planktonli(1027753147) 19:52:28 现在我们就开始讲第四章,第四章的内容是关于 线性分类模型,主要内容有四点:1) Fisher准则的分类,以 ...

  7. Regression:Generalized Linear Models

    作者:桂. 时间:2017-05-22  15:28:43 链接:http://www.cnblogs.com/xingshansi/p/6890048.html 前言 本文主要是线性回归模型,包括: ...

  8. Generalized Linear Models

    作者:桂. 时间:2017-05-22  15:28:43 链接:http://www.cnblogs.com/xingshansi/p/6890048.html 前言 主要记录python工具包:s ...

  9. [Scikit-learn] 1.5 Generalized Linear Models - SGD for Classification

    NB: 因为softmax,NN看上去是分类,其实是拟合(回归),拟合最大似然. 多分类参见:[Scikit-learn] 1.1 Generalized Linear Models - Logist ...

随机推荐

  1. postman Installation has failed: There was an error while installing the application. Check the setup log for more information and contact the author

    Error msg: Installation has failed: There was an error while installing the application. Check the s ...

  2. An interesting combinational problem

    A question of details in the solution at the end of this post of the question is asked by me at MSE. ...

  3. Docker 核心技术之容器

    什么是容器 容器(Container) 容器是一种轻量级.可移植.并将应用程序进行的打包的技术,使应用程序可以在几乎任何地方以相同的方式运行 Docker将镜像文件运行起来后,产生的对象就是容器.容器 ...

  4. ES6 Promise 用法讲解

    Promise是一个构造函数,自己身上有all.reject.resolve这几个眼熟的方法,原型上有then.catch等同样很眼熟的方法. 那就new一个 var p = new Promise( ...

  5. Flutter获取点击元素的位置与大小

    使用 WidgetsBindingObserver获取 class CloseTap extends StatefulWidget { @override _CloseTapTapState crea ...

  6. Webstorm的一些常用快捷键

    ctrl+/ 单行注释ctrl+shift+/块注释Ctrl+X 删除行Ctrl+D 复制行Ctrl+B 快速打开光标处的类或方法Ctrl+F 查找文本Ctrl+R 替换文本ctrl+shift+ + ...

  7. jmeter 安装

    3.1 windows10环境下测试工具jmeter安装与配置 3.1.1下载安装java 浏览器中打开链接:http://down-www.7down.net/pcdown/soft/xiazai/ ...

  8. Bootstrap 模态框(Modal)插件id冲突

    <!DOCTYPE html><html><head>    <meta charset="utf-8">     <titl ...

  9. jmeter笔记(6)--参数化--函数助手

    函数助手提供的功能很多,本次笔记主要整理_CSVRead 函数._Random函数以及_RandomString函数的基础使用方法 1._CSVRead 作用:直接读取csv文件的值生成函数 1.在[ ...

  10. 百度地图--JS版

    百度地图JS版本 ----选择关键字地图展示对应地址---- CSS body, html { width: %; height: %; margin: ; font-family: "微软 ...