Regularization and model selection

假设我们为了一个学习问题尝试从几个模型中选择一个合适的模型。例如,我们可能用一个多项式回归模型hθ(x)=g(θ0+θ1x+θ2x2+…θkxk),我们需要设定一个合适的阶数k,怎样才能决定这个阶数k,以使得最终模型的bias与variance之间能够达到某种平衡,或者,在locally weighted regression 中,我们如何确定参数τ,以及在ℓ1-regularized 的SVM中,如何确定参数C。

在为某个learning problem选择模型的时候,我们假设存在一个有限的模型集合即,M={M1,...Md}供我们选择,比如对于多项式回归模型,第i个模型Mi可以是一个i阶的多项式回归模型,或者说我们想从SVM,神经网络,logistic regression中选择一个合适的模型,那么模型集合就包括这些模型。下面我们介绍一些模型选择的方法。

Cross validation

第一个方法是cross validation,假设我们有一个训练集,并且给定了 empirical risk minimization (ERM),一个比较直观的方法就是利用ERM作为指标,测试每一个模型在训练集上的ERM,然后选择ERM最小的那个模型。但是实际上,这种方式选择的模型并不好,之前我们已经看到过,如果单纯地依靠ERM,会出现过拟合的问题,所以一个更加可靠的方式是hold-out cross validation,我们通过以下步骤实现模型的选择:

1:将训练集S随机分成Strain(一般来说占训练集的70\%)和Scv(占训练集的30\%)。Scv 称为hold-out cross validation set。

2: 只利用Strain训练每一个模型,并且得到假设决策函数hi。

3: 我们利用Scv来测试这些模型与决策函数,在Scv上误差最小的模型将被选择。

通过这种方式,我们可以利用一些模型未曾见过的数据来测试模型的generalization error,这样对模型的generalization error 可以有更好的估计,通常我们选择训练集的1/4−1/3作为hold out cross validation set,30\% 是一个比较经典的选择。

利用hold out cross validation 的一个缺陷是它“浪费”了30\%的数据,这些数据不能用于训练,如果数据量很大的话,这种“浪费”可以接受,但是如果数据本来就很少的时候,我们需要考虑换一种

方式来选择最优的学习模型,下面介绍的这种方式称为\textbf{k-fold cross validation},这种方式每次剔除出来的数据要少一点。

1:将含有m个训练样本的训练集S,随机分成没有交集的k组,每一组含有m/k个样本,我们称每一组为一个子集,即:S1,...Sk。

2:对于一个学习模型,我们每次剔除一组,剩下的k-1组用来训练这个学习模型,并且得到一个对应的决策函数,用得到的决策函数在事先剔除出来的哪一组上做测试,如此循环k次,我们可以得到这个学习模型在这个训练集上的平均误差。

3:对应每一个学习模型都采取这样的策略训练,最终在训练集上的平均误差最小的模型将被选择。

在实际应用中,k一般设为10,这样每次有1/10的数据被剔除出来做测试,剩下的90\%的数据可以用来训练,这种方式比起上一种方式,可以让更多的数据用来训练,不过这种方式需要花费更多的训练时间,因为每一个学习模型我们都要训练k次。

虽然k=10是一个比较常见的选择,但是有的时候,当数据样本非常稀少的时候,我们也会选择k=m以确保每次剔除的数据尽可能地少,在这种情况下,我们每次剔除一个样本,循环m次,然后用平均误差来估计学习模型的generalization error,这种方式也称为leave-one-out cross validation

虽然我们介绍了这几种方法从模型集合中来选择一个合适的模型,但是这几种方法有的时候也可以直接用来评价一个模型或者算法的性能。而实际应用中,我们也经常

用这几种方法来评价一些算法的性能。

Feature Selection

模型选择的一个特殊而重要的应用是特征选择。想象一下,如果我们遇到一个supervised的学习问题,其输入特征的维数n远远大于样本数m,但是这些特征可能只有一部分与问题是有关系的,在这种情况下,可以设计一个特征选择的算法来降低特征的个数,如果一个训练样本含有n个特征,那么就有2n特征组合,如果n很大的话,那么这样算法的计算量会很大,所以一般不会用这种方法选择特征,一个可替代的方法是forward research

1: 初始化F=∅。

2: Repeat { \

(a) 对于i=1,2,...n,如果i∉F,让Fi=F∪{i},利用前面介绍的cross validation的方法对Fi进行评估。(即用学习模型训练Fi,并且估计它的generalization error。)\

(b) 设定F为步骤(a)中的最佳特征子集。\

}

3: 从整个循环过程中,选择最佳的训练子集。

这个方法给出了wrapper model feature selection的一个实例,因为这是一个不断用特征子集”warps”学习模型的过程,需要不断地调用学习模型以评估各个特征子集的性能,除了forward research,还有一种方法就是backward research,这个也很容易理解,forward research 就是从空集一点一点的增大,直到全部n个特征,而backward research恰恰相反,从全部n个特征开始,一点一点地减少,直到空集。虽然这两种特征选择的方法比较有效,但是非常耗时,通常来说,含有n个特征的训练集,

需要O(n2)次调用学习算法。

Filter feature selection是一种相对来说更加高效的特征选择方法,这种方法的核心思想就是计算一些简单的指标S(i)来衡量特征xi与输出y之间的关联性,然后,挑出与输出之间关联最紧密的k个特征作为特征子集。

一个可能的选择就是将S(i)定义为xi与输出y之间的相关性,实际应用中,我们会定义S(i)为xi与输出y之间的mutual information:

MI(xi,y)=∑x∈{0,1}∑y∈{0,1}p(xi,y)logp(xi,y)p(xi)p(y)

其中,p(xi,y),p(xi),p(y) 可以通过训练集估计得到。

上式也可以表示成Kullback-Leibler (KL) divergence:

MI(xi,y)=KL(p(xi,y)||p(xi)p(y))

KL-divergence 给出了衡量概率分布p(xi,y)与p(xi)p(y)之间的差异的一种方式,如果xi与输出y是相互独立的,那么可以知道:p(xi,y)=p(xi)p(y),进而这两个分布的KL-divergence将会是0,这个从直观上的理解就是,既然xi与输出y是相互独立的,意味着xi与输出y之间没有什么联系,因此S(i)会很小,如果xi与输出y之间有很强的联系,那么MI(xi,y)将会很大。

Bayesian statistics and regularization

下面,我们再介绍一种防止overfitting的方法,之前我们介绍过利用最大似然估计来求参数的方法,我们会建立如下的目标函数:

θML=argmaxθ∏i=1np(y(i)|x(i);θ)

我们接下来的讨论,会假设参数θ是一个未知的常数,而不是一个随机数,这是从\textbf{频率统计}的角度出发,所以我们的任务就是要找到这个未知的常数。而如果从贝叶斯的角度考虑,会认为参数θ是一个未知的随机数,在这种观点下,我们会建立参数θ的一个先验概率分布,p(θ),对参数θ会有一个先验估计,给定一个训练集S={(x(i),y(i))}mi=1,当我们需要对一个新输入的特征向量x做预测时,我们可以先计算参数p(θ)相对于训练集S的后验分布:

p(θ|S)=p(S|θ)p(θ)p(S)=(∏mi=1p(y(i)|x(i),θ))p(θ)∫θ(∏mi=1p(y(i)|x(i),θ)p(θ))dθ

上式中,p(y(i)|x(i),θ)取决于我们所选择的学习模型,比如,如果我们选择Bayesian logistic regression,那么我们可能会选择:

p(y(i)|x(i),θ)=hθ(x(i))y(i)(1−hθ(x(i)))1−y(i),其中,hθ(x(i))=1/(1+exp(−θTx(i))),

当我们需要对一个新输入的样本做预测的时候,我们可以利用参数θ的后验分布计算输出y的后验分布:

p(y|x,S)=∫θp(y|x,θ)p(θ|S)dθ

同样地,我们可以计算输出y关于x的期望:

E[y|x,S]=∫yyp(y|x,S)dy

实际应用中,由于对参数θ的积分非常困难,所以一般我们不会直接用上面的表达式运算参数θ的后验分布,我们会采用估计的办法,用下面的表达式对参数进行估计:

θMAP=argmaxθ∏i=1mp(y(i)|x(i),θ)p(θ)

上式相当于将参数θ的后验分布转化为一个“点估计”,这个表达式称为maximum a posteriori (MAP),可以看到这个表达式和最大似然估计的表达式很像,唯一的区别在于MAP多了一个参数θ的先验分布p(θ),在实际地应用中,我们会将先验分布估计为高斯分布,即θ∼N(0,τ2I),基于这个先验分布,估计得到的参数θ会比最大似然估计得到的参数θ拥有更小的范数,这样会让MAP学习模型比最大似然估计模型对overfitting有更稳健的抗干扰性。

机器学习 Regularization and model selection的更多相关文章

  1. 转:机器学习 规则化和模型选择(Regularization and model selection)

    规则化和模型选择(Regularization and model selection) 转:http://www.cnblogs.com/jerrylead/archive/2011/03/27/1 ...

  2. Bias vs. Variance(2)--regularization and bias/variance,如何选择合适的regularization parameter λ(model selection)

    Linear regression with regularization 当我们的λ很大时,hθ(x)≍θ0,是一条直线,会出现underfit:当我们的λ很小时(=0时),即相当于没有做regul ...

  3. Andrew Ng机器学习公开课笔记 -- Regularization and Model Selection

    网易公开课,第10,11课 notes,http://cs229.stanford.edu/notes/cs229-notes5.pdf   Model Selection 首先需要解决的问题是,模型 ...

  4. Regularization and model selection

    Suppose we are trying select among several different models for a learning problem.For instance, we ...

  5. Scikit-learn:模型选择Model selection

    http://blog.csdn.net/pipisorry/article/details/52250983 选择合适的estimator 通常机器学习最难的一部分是选择合适的estimator,不 ...

  6. 学习笔记之Model selection and evaluation

    学习笔记之scikit-learn - 浩然119 - 博客园 https://www.cnblogs.com/pegasus923/p/9997485.html 3. Model selection ...

  7. 评估预测函数(3)---Model selection(选择多项式的次数) and Train/validation/test sets

    假设我们现在想要知道what degree of polynomial to fit to a data set 或者 应该选择什么features 或者 如何选择regularization par ...

  8. Spark2 Model selection and tuning 模型选择与调优

    Model selection模型选择 ML中的一个重要任务是模型选择,或使用数据为给定任务找到最佳的模型或参数. 这也称为调优. 可以对诸如Logistic回归的单独Estimators进行调整,或 ...

  9. scikit-learn:3. Model selection and evaluation

    參考:http://scikit-learn.org/stable/model_selection.html 有待翻译,敬请期待: 3.1. Cross-validation: evaluating ...

随机推荐

  1. 设置windows时间开机同步方法

    本作品由Man_华创作,采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可.基于http://www.cnblogs.com/manhua/上的作品创作. 适用场景: 主板电池 ...

  2. Nginx 一些常用的URL 重写方法

    url重写应该不陌生,不管是SEO URL 伪静态的需要,还是在非常流行的wordpress里,重写无处不在. 1. 在 Apache 的写法 RewriteCond  %{HTTP_HOST}  n ...

  3. POJ 2155 Matrix(二维树状数组,绝对具体)

    Matrix Time Limit: 3000MS   Memory Limit: 65536K Total Submissions: 20599   Accepted: 7673 Descripti ...

  4. JVM、垃圾回收、内存调优、常见參数

    一.什么是JVM JVM是Java Virtual Machine(Java虚拟机)的缩写.JVM是一种用于计算设备的规范.它是一个虚构出来的计算机,是通过在实际的计算机上仿真模拟各种计算机功能来实现 ...

  5. Android中BaseAdapter使用基础点

    Android中要填充一些控件(如ListView)经常须要用到Adapter来实现,经常使用的有ArrayAdapter,SimpleAdapter, CursorAdapter,BaseAdapt ...

  6. linux uart驱动——uart platfrom 注册(三)

    一:注册platform device 注册一个platfrom device一般需要初始化两个内容,设备占用的资源resource和设备私有数据dev.platfrom_data.设备的resour ...

  7. Oracle学习第一篇—安装和简单语句

    一 安装  10G ----不适合Win7 Visual Machine-++++Visual Hard Disk 先安装介质(VM)---便于删除 11G-----适合Win7 1 把win64_1 ...

  8. POJ2407_Relatives【欧拉phi函数】【基本】

    Relatives Time Limit: 1000MS Memory Limit: 65536K Total Submissions: 11422 Accepted: 5571 Descriptio ...

  9. Trie树,又称单词查找树、字典

    在百度或淘宝搜索时,每输入字符都会出现搜索建议,比如输入“北京”,搜索框下面会以北京为前缀,展示“北京爱情故事”.“北京公交”.“北京医院”等等搜索词.实现这类技术后台所采用的数据结构是什么?[中国某 ...

  10. Java知识点梳理——常用方法总结

    1.查找字符串最后一次出现的位置 String str = "my name is zzw"; int lastIndex = str.lastIndexOf("zzw& ...