I. 背景介绍

1. 学习曲线(Learning Curve)

我们都知道在手工调试模型的参数的时候,我们并不会每次都等到模型迭代完后再修改超参数,而是待模型训练了一定的epoch次数后,通过观察学习曲线(learning curve, lc) 来判断是否有必要继续训练下去。那什么是学习曲线呢?主要分为两类:

  • 1.模型性能是训练时间或者迭代次数的函数:performance=f(time) 或 performance=f(epoch)。这个也就是我们常用到的方法,即横轴记录训练时间(或迭代次数),纵轴记录准确率。示意图如下:

  • 2.模型性能是用于训练模型的数据集大小的函数:performance=f(train_data_size),即横轴记录训练集大小,纵轴记录准确率。

本文使用的是第一种。

2. 饱和函数(Saturating Function)

饱和函数简单理解就是当输入达到一定值后,输出不再有明显变化,或变化很小,所以称之为饱和。

常见的饱和函数有:

  • 指数函数
  • sigmoid函数

仔细观察可以看到其实理想的学习曲线就是饱和函数。而我们手工判断一个网络是否还有必要继续训练下去的依据,就是看是否已经到了那个临界点,或者我们称之为饱和点

基于此,本文选取了如图所示的11种饱和函数用来拟合学习曲线。

怎么个拟合法呢?这里做个简单的介绍,下文会更加详细介绍。

就是说假如我们通过训练模型和评估模型得到了n个学习曲线值(或者说是准确率),即为\(y_{1:n}=\{y_1,...,y_n\}\)。

然后我们通过使用上面的饱和函数对这些数据进行拟合计算出饱和函数的参数,这样我们就可以预测后面的模型性能了。这样也就可以做到断舍离了。

3. 超参数优化(Hyperparameters Optimization)

假设经过上面的步骤得到了饱和函数的参数,但是我们还是需要对超参数进行采样和优化的。

而常用的超参数优化算法有很多种,其中贝叶斯优化算法是使用最多且较为有效的方法。而基于贝叶斯的优化算法中使用广泛的有如下三种:

  • Spearmint
  • SMAC
  • Tree Parzen Estimator(TPE)

有文章对上面三种方法做了比较,得出如下结论:

  • Spearmint在数值型超参数较少的情况下表现更加好
  • SMAC和TPE在高维超参数和部分离散超参数情况下表现更好

本文使用了后两者优化算法。

4. 无信息先验(Uninformative prior)

由名字就可以很好理解无信息先验是什么意思了,就是说事先我们不知道任何其他信息,所以对于某一事件的先验概率无从得知。这种情况下通常是假设先验概率是均匀分布,因为这样不会对任何事件或元素有偏心。

II. 本文方法

1. Learning Curve Model

假设有K个饱和函数可供选择,记为\(\{f_1,...,f_K\}\),每个饱和函数\(f_i\)都由一组超参数\(θ_i\)决定。

部分观察已知的学习曲线上的值记为\(y_{1:n}\),每个值
\[y_t=f_k(t|θ)+\epsilon \tag{1}\]
其中\(\epsilon\)表示高斯分布的噪声,即\(\epsilon\sim \cal{N}(0,\sigma^2)\).

所以单个观测值\(y_t\)的概率分布是:
\[p(y_t|θ_k,\sigma^2)=\cal{N}(y_t;f_k(t|θ_k),\sigma^2) \tag{2}\]

2. A weighted Probabilistic Learning Curve Model

上面的y值只由一个饱和函数决定,那为什么不能将这些饱和函数都利用起来呢?所以我们可以给每个饱和函数一个权重,因此可以得到如下合并(combined)饱和函数:

\[
f_{comb}(t|\xi) =\sum_{k=1}^Kw_kf_k(t|θ_k) \tag{3}
\]

\[
\xi = (w_1,...,w_K,θ_1,...θ_K,\sigma^2) \tag{4}
\]

所以此时有

\[y_t=f_{comb}(t|\xi)+\epsilon=\sum_{k=1}^Kw_kf_k(t|θ_k)+\epsilon \tag{5}\]
\[p(y_{1:n}|\xi)=\prod_{t=1}^n\cal{N}(y_t;f_{comb}(t|\xi),\sigma^2) \tag{6}\]

现在有了如公式(5)的模型了,所以一个很简单的办法就是根据\(y_{1:n}\)找到该模型的最大似然参数估计,然后将参数带入模型中,这样我们就可以借助这个具体的模型预测后面的模型性能了,从而可以自动决定是否还有必要继续训练下去。但是这样有个缺点,按照原文的说法是:

However, this would not properly model the uncertainty in the model parameters. Since our predictive termination criterion aims at only terminating runs that are highly unlikely to improve on the best run observed so far we need to model uncertainty as truthfully as possible and will hence adopt a Bayesian perspective, predicting values ym using Markov Chain Monte Carlo (MCMC) inference

公式(4)中的参数的先验概率都是无信息先验,另外又因为要确保公式(3)最终还是一个饱和函数,所以规定每个权重都必须为正,参数\(\xi\)的概率分布定义如下:

\[
p(\xi)∝(\prod_{k=1}^Kp(w_k)p(θ_k))p(\sigma^2)\cal{I}(f_{comb}(1|\xi)<f_{comb}(m|\xi)) \tag{7}
\]

其中:

  • \(\cal(I)(f_{comb}(1|\xi)<f_{comb}(m|\xi))\)是一个indicator function,表示合并饱和函数必须是一个增函数。
  • 另外权重\(w_k\)的概率分布为:
    \[
    p(w_k)∝ \begin{cases}
    1 & if \,\,\, w_k >0 \\
    0 & otherwise
    \end{cases} \tag{8}
    \]
    初始化时令\(w_k=\frac{1}{k},k=1,...,K\)

现在有了公式(6)(7)之后,我们就可以得到后验概率:

\[p(\xi|y_{1:n})∝p(y_{1:n}|\xi)p(\xi) \tag{9}\]

借助这个后验概率我们就可以对联合超参数和权重搜索空间\(\xi\)进行MCMC(Markov Chain Monte Carlo) 采样了。那具体怎么个采样法呢?

  • 令所有的模型参数\(θ_k\)设置为其对应饱和函数的最大似然估计。
  • 模型权重均匀分布,即\(w_k=\frac{1}{k}\)
  • 噪声参数初始化为它的最大似然估计,即\(\hat{\sigma}^2=\frac{1}{n}\sum_{t=1}^n(y_t-f_{comb}(t|\xi))^2\)

3. Extrapolate Learning Curve

上面已经将采样的方式介绍清楚了,那么接下来介绍判断是否继续训练模型的细节。

1) 预测模型性能

假设已观察得到的数据为\(y_{1:n}\),那么对于m>n的预测值我们可以用到均值预测,从而减少误差,公式如下:

\[E[y_m|y_{1:n}]≈\frac{1}{S}\sum_{s=1}^Sf_{comb}(m|\xi_s) \tag{10}\]

我们可以设定一个阈值\(\hat{y}\),这个阈值是什么意思呢?就是说假如我们现在有一个分类任务,我觉得分类准确度大于89%就已经很不错了,那么阈值就可以设为89%。那么只要预测的\(y_m\)能够大于或等于0.89,就说明这个还有继续训练下去的必要。

所以下面我们还可以求出\(y_m\)大于阈值的概率分布。

2) 模型性能大于阈值的概率分布

因为对于每一个固定的参数\(\xi\)而言,预测分布\(p(y_m>\hat{y}|y_{1:n})\)都是一个高斯分布,所以有:

\[
\begin{align}
p(y_m>\hat{y}|y_{1:n}) &= \int p(\xi|y_{1:n})p(y_m>\hat{y}|\xi)d\xi \notag \\
&≈\frac{1}{S}\sum_{s=1}^Sp(y_m>\hat{y}|\xi_s) \notag \\
&= \frac{1}{S}\sum_{s=1}^S(1-\Phi(\hat{y};f_{comb}(m|\xi_s),\sigma_s^2)) \tag{11}
\end{align}
\]

其中\(\Phi(·;μ,\sigma^2)\)是均值为μ,方差为\(\sigma^2\)的高斯累计分布函数。

3) 算法细节

  • 预设的最大迭代次数为\(e_{max}\)。
  • 在每个epoch中会记录K次模型在验证集上的性能,即\(y_{1:K}\)
  • 每隔P个epoch,就将之前记录的模型性能汇总起来得到n(\(n=K×P\))个性能数据,即\(y_{1:n}\)。然后根据这些数据预测最后step的性能,即预测第\(m=K×e_{max}\)步的性能\(y_m\)。这里需要用到公式(10)求出预测性能\(y_m\)大于预设性能\(\hat{y}\)的概率,如果概率大于阈值\(δ\),那么继续下一P个epoch的训练。反之则返回预测验证集准确率\(E[y_m|y_{1:n}]\) (由公式(10)求得)。

文中将这一过程称为: predictive termination criterion。

MARSGGBO♥原创







2019-1-5

论文笔记系列-Speeding Up Automatic Hyperparameter Optimization of Deep Neural Networks by Extrapolation of Learning Curves的更多相关文章

  1. 论文笔记:Mastering the game of Go with deep neural networks and tree search

    Mastering the game of Go with deep neural networks and tree search Nature 2015  这是本人论文笔记系列第二篇 Nature ...

  2. 论文笔记——ThiNet: A Filter Level Pruning Method for Deep Neural Network Compreesion

    论文地址:https://arxiv.org/abs/1707.06342 主要思想 选择一个channel的子集,然后让通过样本以后得到的误差最小(最小二乘),将裁剪问题转换成了优化问题. 这篇论文 ...

  3. Neural Networks and Deep Learning 课程笔记(第四周)深层神经网络(Deep Neural Networks)

    1. 深层神经网络(Deep L-layer neural network ) 2. 前向传播和反向传播(Forward and backward propagation) 3. 总结 4. 深层网络 ...

  4. 【论文笔记系列】AutoML:A Survey of State-of-the-art (下)

    [论文笔记系列]AutoML:A Survey of State-of-the-art (上) 上一篇文章介绍了Data preparation,Feature Engineering,Model S ...

  5. 论文笔记系列-Neural Network Search :A Survey

    论文笔记系列-Neural Network Search :A Survey 论文 笔记 NAS automl survey review reinforcement learning Bayesia ...

  6. 《Improving Deep Neural Networks:Hyperparameter tuning, Regularization and Optimization》课堂笔记

    Lesson 2 Improving Deep Neural Networks:Hyperparameter tuning, Regularization and Optimization 这篇文章其 ...

  7. 论文笔记系列-Auto-DeepLab:Hierarchical Neural Architecture Search for Semantic Image Segmentation

    Pytorch实现代码:https://github.com/MenghaoGuo/AutoDeeplab 创新点 cell-level and network-level search 以往的NAS ...

  8. [C4] Andrew Ng - Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization

    About this Course This course will teach you the "magic" of getting deep learning to work ...

  9. Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week1, Assignment(Initialization)

    声明:所有内容来自coursera,作为个人学习笔记记录在这里. Initialization Welcome to the first assignment of "Improving D ...

随机推荐

  1. 点赞功能与redis

    转:https://edu.aliyun.com/a/20538 摘要: 前言点赞其实是一个很有意思的功能.基本的设计思路有大致两种, 一种自然是用mysql等数据库直接落地存储, 另外一种就是利用点 ...

  2. springcloud使用zookeeper作为config的配置中心

    https://blog.csdn.net/CSDN_Stephen/article/details/78856323 仓库更新了,本地如何更新: 使用configserver作为配置中心: http ...

  3. npx小工具

    npm v5.2.0引入的一条命令(npx),引入这个命令的目的是为了提升开发者使用包内提供的命令行工具的体验. 举例:使用create-react-app创建一个react项目. 老方法: npm ...

  4. 《老梁四大名著情商课》笔记-学学TA,你就是聚会的万人迷

    <老梁四大名著情商课>笔记-学学TA,你就是聚会的万人迷 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 现在社会学家有一个统计,说中国处在单身状态大概有2个亿.这些人中 ...

  5. CentOS7 上以 RPM 包方式安装 Oracle 18c 单实例

    安装阿里云 YUM 源 https://opsx.alibaba.com/mirror?lang=zh-CN 一.安装Oracle数据库 1.安装 Oracle 预安装 RPM yum -y loca ...

  6. SVN快速入门笔记【转】

    1. SVN版本控制软件目的 协作开发 远程开发 版本回退 2. 什么是SVN subVersion 支持平台操作 支持版本回退 3. 获取SVN软件 属于C/S结构软件(客户端与服务端) serve ...

  7. Linux 下装逼技巧

    ``` 1.下载cmatrix-1.2a.tar.gz文件 [root@localhost ~]# wget https://jaist.dl.sourceforge.net/project/cmat ...

  8. flask异步

    demo def runFlask(port): init() app.config[' app.run(port=port, threaded=True) CORS(app, supports_cr ...

  9. JDK源码之ArrayList

    序言 ArrayList底层通过数组实现. ArrayList即动态数组,实现了动态的添加和减少元素 需要注意的是,容量拓展,是创建一个新的数组,然后将旧数组上的数组copy到新数组,这是一个很大的消 ...

  10. Android弹出窗口

    protected void PopUp() { final PopupWindow popup = new PopupWindow(TestActivity.this); View popView ...