机器学习算法中如何选取超参数:学习速率、正则项系数、minibatch size

本文是《Neural networks and deep learning》概览 中第三章的一部分,讲机器学习算法中,如何选取初始的超参数的值。(本文会不断补充)



学习速率(learning rate,η)

运用梯度下降算法进行优化时,权重的更新规则中,在梯度项前会乘以一个系数,这个系数就叫学习速率η。下面讨论在训练时选取η的策略。

  • 固定的学习速率。如果学习速率太小,则会使收敛过慢,如果学习速率太大,则会导致代价函数振荡,如下图所示。就下图来说,一个比较好的策略是先将学习速率设置为0.25,然后在训练到第20个Epoch时,学习速率改为0.025。

关于为什么学习速率太大时会振荡,看看这张图就知道了,绿色的球和箭头代表当前所处的位置,以及梯度的方向,学习速率越大,那么往箭头方向前进得越多,如果太大则会导致直接跨过谷底到达另一端,所谓“步子太大,迈过山谷”。

在实践中,怎么粗略地确定一个比较好的学习速率呢?好像也只能通过尝试。你可以先把学习速率设置为0.01,然后观察training cost的走向,如果cost在减小,那你可以逐步地调大学习速率,试试0.1,1.0….如果cost在增大,那就得减小学习速率,试试0.001,0.0001….经过一番尝试之后,你可以大概确定学习速率的合适的值。



Early Stopping

所谓early stopping,即在每一个epoch结束时(一个epoch即对所有训练数据的一轮遍历)计算 validation data的accuracy,当accuracy不再提高时,就停止训练。这是很自然的做法,因为accuracy不再提高了,训练下去也没用。另外,这样做还能防止overfitting。

那么,怎么样才算是validation accuracy不再提高呢?并不是说validation accuracy一降下来,它就是“不再提高”,因为可能经过这个epoch后,accuracy降低了,但是随后的epoch又让accuracy升上去了,所以不能根据一两次的连续降低就判断“不再提高”。正确的做法是,在训练的过程中,记录最佳的validation accuracy,当连续10次epoch(或者更多次)没达到最佳accuracy时,你可以认为“不再提高”,此时使用early stopping。这个策略就叫“ no-improvement-in-n”,n即epoch的次数,可以根据实际情况取10、20、30….



可变的学习速率

在前面我们讲了怎么寻找比较好的learning rate,方法就是不断尝试。在一开始的时候,我们可以将其设大一点,这样就可以使weights快一点发生改变,从而让你看出cost曲线的走向(上升or下降),进一步地你就可以决定增大还是减小learning rate。

但是问题是,找出这个合适的learning rate之后,我们前面的做法是在训练这个网络的整个过程都使用这个learning rate。这显然不是好的方法,在优化的过程中,learning rate应该是逐步减小的,越接近“山谷”的时候,迈的“步伐”应该越小。

在讲前面那张cost曲线图时,我们说可以先将learning rate设置为0.25,到了第20个epoch时候设置为0.025。这是人工的调节,而且是在画出那张cost曲线图之后做出的决策。能不能让程序在训练过程中自动地决定在哪个时候减小learning rate?

答案是肯定的,而且做法很多。一个简单有效的做法就是,当validation accuracy满足 no-improvement-in-n规则时,本来我们是要early stopping的,但是我们可以不stop,而是让learning rate减半,之后让程序继续跑。下一次validation accuracy又满足no-improvement-in-n规则时,我们同样再将learning rate减半(此时变为原始learni rate的四分之一)…继续这个过程,直到learning rate变为原来的1/1024再终止程序。(1/1024还是1/512还是其他可以根据实际确定)。【PS:也可以选择每一次将learning rate除以10,而不是除以2.】

A readable recent paper which demonstrates the benefits of variable learning rates in attacking MNIST.《Deep Big Simple Neural Nets Excel on HandwrittenDigit Recognition》



正则项系数(regularization parameter, λ)

正则项系数初始值应该设置为多少,好像也没有一个比较好的准则。建议一开始将正则项系数λ设置为0,先确定一个比较好的learning rate。然后固定该learning rate,给λ一个值(比如1.0),然后根据validation accuracy,将λ增大或者减小10倍(增减10倍是粗调节,当你确定了λ的合适的数量级后,比如λ = 0.01,再进一步地细调节,比如调节为0.02,0.03,0.009之类。)

在《Neural Networks:Tricks of the Trade》中的第三章『A Simple Trick for Estimating the Weight Decay Parameter』中,有关于如何估计权重衰减项系数的讨论,有基础的读者可以看一下。



Mini-batch size

首先说一下采用mini-batch时的权重更新规则。比如mini-batch size设为100,则权重更新的规则为:

也就是将100个样本的梯度求均值,替代online learning方法中单个样本的梯度值:

当采用mini-batch时,我们可以将一个batch里的所有样本放在一个矩阵里,利用线性代数库来加速梯度的计算,这是工程实现中的一个优化方法。

那么,size要多大?一个大的batch,可以充分利用矩阵、线性代数库来进行计算的加速,batch越小,则加速效果可能越不明显。当然batch也不是越大越好,太大了,权重的更新就会不那么频繁,导致优化过程太漫长。所以mini-batch size选多少,不是一成不变的,根据你的数据集规模、你的设备计算能力去选。

The way to go is therefore to use some acceptable (but not necessarily optimal) values for the other hyper-parameters, and then trial a number of different mini-batch sizes, scaling η as above. Plot the validation accuracy versus time (as in, real elapsed time, not epoch!), and choose whichever mini-batch size gives you the most rapid improvement in performance. With the mini-batch size chosen you can then proceed to optimize the other hyper-parameters.



更多资料

LeCun在1998年的论文《Efficient BackProp》

Bengio在2012年的论文《Practical recommendations for gradient-based training of deep architectures》,给出了一些建议,包括梯度下降、选取超参数的详细细节。

以上两篇论文都被收录在了2012年的书《Neural Networks: Tricks of the Trade》里面,这本书里还给出了很多其他的tricks。



转载请注明出处:http://blog.csdn.net/u012162613/article/details/44265967

机器学习算法中如何选取超参数:学习速率、正则项系数、minibatch size的更多相关文章

  1. 机器学习算法中怎样选取超參数:学习速率、正则项系数、minibatch size

    本文是<Neural networks and deep learning>概览 中第三章的一部分,讲机器学习算法中,怎样选取初始的超參数的值.(本文会不断补充) 学习速率(learnin ...

  2. 机器学习算法中的网格搜索GridSearch实现(以k-近邻算法参数寻最优为例)

    机器学习算法参数的网格搜索实现: //2019.08.031.scikitlearn库中调用网格搜索的方法为:Grid search,它的搜索方式比较统一简单,其对于算法批判的标准比较复杂,是一种复合 ...

  3. 机器学习:调整kNN的超参数

    一.评测标准 模型的测评标准:分类的准确度(accuracy): 预测准确度 = 预测成功的样本个数/预测数据集样本总数: 二.超参数 超参数:运行机器学习算法前需要指定的参数: kNN算法中的超参数 ...

  4. 机器学习算法中的准确率(Precision)、召回率(Recall)、F值(F-Measure)

    摘要: 数据挖掘.机器学习和推荐系统中的评测指标—准确率(Precision).召回率(Recall).F值(F-Measure)简介. 引言: 在机器学习.数据挖掘.推荐系统完成建模之后,需要对模型 ...

  5. 网格搜索与K近邻中更多的超参数

    目录 网格搜索与K近邻中更多的超参数 一.knn网格搜索超参寻优 二.更多距离的定义 1.向量空间余弦相似度 2.调整余弦相似度 3.皮尔森相关系数 4.杰卡德相似系数 网格搜索与K近邻中更多的超参数 ...

  6. 机器学习-kNN-寻找最好的超参数

    一 .超参数和模型参数 超参数:在算法运行前需要决定的参数 模型参数:算法运行过程中学习的参数 - kNN算法没有模型参数- kNN算法中的k是典型的超参数 寻找好的超参数 领域知识 经验数值 实验搜 ...

  7. 机器学习算法中GBDT和XGBOOST的区别有哪些

    首先xgboost是Gradient Boosting的一种高效系统实现,并不是一种单一算法.xgboost里面的基学习器除了用tree(gbtree),也可用线性分类器(gblinear).而GBD ...

  8. 机器学习算法中的评价指标(准确率、召回率、F值、ROC、AUC等)

    参考链接:https://www.cnblogs.com/Zhi-Z/p/8728168.html 具体更详细的可以查阅周志华的西瓜书第二章,写的非常详细~ 一.机器学习性能评估指标 1.准确率(Ac ...

  9. 郑捷《机器学习算法原理与编程实践》学习笔记(第四章 推荐系统原理)(二)kmeans

    (上接第二章) 4.3.1 KMeans 算法流程 算法的过程如下: (1)从N个数据文档随机选取K个文档作为质心 (2)对剩余的每个文档测量其到每个质心的距离,并把它归到最近的质心的类 (3)重新计 ...

随机推荐

  1. 搭建Elasticsearch平台

    https://cloud.tencent.com/developer/article/1189282 https://blog.csdn.net/qq_34021712/article/detail ...

  2. HDU 6088 Rikka with Rock-paper-scissors(NTT+欧拉函数)

    题意 \(n\) 局石头剪刀布,设每局的贡献为赢的次数与输的次数之 \(\gcd\) ,求期望贡献乘以 \(3^{2n}\) ,定义若 \(xy=0\) 则,\(\gcd(x,y)=x+y\) 思路 ...

  3. msvc命令行cl编译c程序问题及解决

    1.cmd命令行cl提示没有这玩意儿 装上Visual Studio之类 2.cl main.c提示缺dll everything搜dll所在路径,在环境配置PATH增加对应bin.IDE 3.cl ...

  4. centos7安装node

    centos7安装node 二进制文件安装 node=v10.13.0 file=node-${node}-linux-x64 wget https://nodejs.org/dist/${node} ...

  5. Pandas 基础(6) - 用 replace() 函数处理不合理数据

    首先, 还是新建一个 jupyter notebook, 然后引入 csv 文件(此文件我已上传到博客园): import pandas as pd import numpy as np df = p ...

  6. vuejs点滴

    博客0.没事的时候可以看的一些博客:https://segmentfault.com/a/1190000005832164 http://www.tuicool.com/articles/vQBbii ...

  7. 【转】 ISP概述、工作原理及架构

    1.概述 ISP全称Image Signal Processing,即图像信号处理.主要用来对前端图像传感器输出信号处理的单元,以匹配不同厂商的图象传感器. ISP 通过一系列数字图像处理算法完成对数 ...

  8. Web版记账本开发记录(四)

    今天已经是是开发软件的第四天了,今天遇到了一些简单的小问题,虽然简单,但是自己仍旧不具备修改的能力, 自己尝试了各种办法仍旧没有修改成功,在收入表就状况百出,错误不断. 我决定明天还是静下心来好好地学 ...

  9. 做h5动画会用到的一个很好的缓动算法库

    http://www.zhangxinxu.com/wordpress/2016/12/how-use-tween-js-animation-easing/

  10. 查看指定库对应GCC版本

    strings /usr/lib/libstdc++.so.6 | grep GLIBCXX