【笔记】Reptile-一阶元学习算法
目录
论文信息
Nichol A , Achiam J , Schulman J . On First-Order Meta-Learning Algorithms[J]. 2018.
一、摘要
本文主要考虑元学习
问题,即存在一个任务分布(a distribution of tasks),从这个分布中抽取许多任务来训练元学习模型(或代理),使其在处理从这个分布中抽取的以前从未遇到过的任务时能更快的学习(即表现得更好)。
本文通过分析一系列仅在元学习更新(meta-learning update)过程中使用一阶微分(first-order derivation)就能在新任务上实现快速微调的关于参数初始化的算法,验证了一阶元学习算法在一些完善的few-shot分类基准上的有效性,同时还对这些算法的可行性进行了理论分析。
这些一阶元学习算法主要包括MAML的近似表示(忽略二阶微分)——first-order MAML(简记:FOMAML)以及本文提出的Reptile算法。
二、背景
2.1 雅可比矩阵(Jacobi Matrix)
是函数的一阶偏导数以一定方式排列成的矩阵,其体现了一个可微方程与给出点的最优线性逼近。
- 假设\(F:\mathbb{R}_\mathrm{n}\rightarrow \mathbb{R}_\mathrm{m}\)是一个从n维欧氏空间映射到到m维欧氏空间的函数。这个函数由m个实函数组成:
\[
F=\begin{cases}
f_1(x_1,\cdots,x_n)\\
f_2(x_1,\cdots,x_n)\\
\cdots\\
f_m(x_1,\cdots,x_n)\\
\end{cases}
\]
- 这些函数的偏导数(如果存在)可以组成一个m行n列的矩阵,这个矩阵就是所谓的雅可比矩阵:
\[
J = [\frac{\partial f}{\partial x_1} \cdots \frac{\partial f}{\partial x_n}]
=\begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{bmatrix}
\]
2.2 泰勒公式
\[
\begin{array}{*{20}{l}}
{f{ \left( {x} \right) }{\begin{array}{*{20}{l}}
{=f{ \left( {\mathop{{x}}\nolimits_{{0}}} \right) }+{f \prime }{ \left( {\mathop{{x}}\nolimits_{{0}}} \right) }{ \left( {x-\mathop{{x}}\nolimits_{{0}}} \right) }+\frac{{f '' { \left( {\mathop{{x}}\nolimits_{{0}}} \right) }}}{{2!}}\mathop{{ \left( {x-\mathop{{x}}\nolimits_{{0}}} \right) }}\nolimits^{{2}}+ \cdots +\frac{{\mathop{{f}}\nolimits^{{ \left( {n} \right) }}{ \left( {\mathop{{x}}\nolimits_{{0}}} \right) }}}{{n!}}\mathop{{ \left( {x-\mathop{{x}}\nolimits_{{0}}} \right) }}\nolimits^{{n}}+\mathop{{R}}\nolimits_{{n}}{ \left( {x} \right) }}\\
\end{array}}}\\
\end{array}
\]
2.3 领头阶(Leading Order)
一个解析表达式按照泰勒公式成无穷级数(或者多项式),根据所研究的定义域,每一个展开项所贡献的大小是不会都相同的,根据它们对解析表达式精确值的贡献大小将这些项分门别类地叫做领头阶、次领头阶、次次领头阶…
2.4 转导与归纳
摘自:维基百科(https://en.wikipedia.org/wiki/Transduction_(machine_learning))
- 转导(Transduction):从观察到的特定(训练)案例到特定(测试)案例的推理。
- 归纳(Induction):从观察到的训练案例到一般规则的推理,然后将其应用于测试案例。
示例:
给出一个点的集合,其中一些点被标记了为A,B或C,但是大多数点没有被标记,用?表示。训练的目的是预测所有未标记点的“最佳”标签。
采用归纳的思想,是使用有标记的点来训练监督学习算法,然后让其预测所有未标记的点的标签。但是,对于这个问题,监督学习算法将仅具有五个标记点,建立捕获该数据结构的模型肯定会很困难。例如,如果使用最近邻居算法,则即使很明显可以看到中间附近的点与标记为“ B”的点属于同一个群集,也有可能会被标记为“ A”或“ C”。
- 转导在执行标记任务时,能够考虑所有点,而不仅仅是标记点。在这种情况下,转导算法将根据它们原本所属的簇来标记未标记的点。因此,中间的点很可能会标记为“ B”,因为它们的位置非常靠近该群集。
- 转导的一个优势是,它可以使用较少的标记点来进行更好的预测,因为它使用了未标记点中的自然隔断(Break)。
- 转导的一个缺点是它没有建立预测模型。如果将先前未知的点添加到集合中,则需要对所有点重复整个转换算法,以预测标签。如果数据在流式的数据中逐渐可用,则在计算上可能会很昂贵。此外,这可能会导致某些旧点的预测发生变化(取决于应用程序可能是好是坏)。另一方面,有监督的学习算法可以立即标记新点,而计算成本却很少。
三、介绍
3.1 算法动机
- 人类在进行一项新的任务时,通常使用了大量编码于人类大脑和DNA中的先验知识。得益于此,人类具有快速学习的能力,在数学上这种能力的获得可以解释为贝叶斯推断(Bayesian Inference)过程,这也正是开发出能达到人类水平的学习速度的算法的关键。但实际上使用深度神经网络开发出计算上可行的贝叶斯机器学习算法是极具挑战的。
- 与此不同,元学习算法并没有尝试去模拟贝叶斯推断过程,而是试图使用任务数据集直接优化快速学习算法,这种算法作为一种“代理”,能够在新任务上快速适应并学习。两类常见的元学习方法:
- 基于模型:将学习算法编码为循环网络模型中的权重,从而在训练过程中对元学习模型的参数进行更新。
- 基于初始化:
- pre-training:在大量数据上(ImageNet)上学习网络的初始化参数,然后在新任务上进行测试时对这些参数进行微调。这种方法无法保证得出的参数便于调整,为了达到良好的性能有时还需要一些特殊的技巧(ad-hoc tricks)。
- MAML:在优化过程中对初始化参数进行微分更新,以获得一个敏感的基于梯度的学习算法。但是这种算法使用了二阶微分计算,增大了计算开销。
- FOMAML:作为MAML的变种,忽略了二阶微分项,节省了计算开销,但损失了部分梯度信息。
- 针对某些问题使用依赖于高阶梯度的技术可能出现的复杂性,本文探讨了基于一阶梯度信息的元学习算法。
3.2 本文贡献
- 指出FOMAML的实现相比以前的认知更加容易。
- 提出了Reptile算法。这种算法与联合训练(joint training,通过训练来最小化在一系列训练任务上期望损失)非常相似,并且与FOMAML紧密相关,但是与FOMAML不同,Reptile无需对每一个任务进行训练-测试(training-testing)划分。
- 对FOMAML和Reptile进行了理论分析,表明两者都对任务内泛化进行了优化。
- 在对Mini-ImageNet和Omniglot数据集进行实证评价的基础上,提出了实施最佳实践的一些见解。
四、实现
4.1 FOMAML简化实现
MAML优化过程的公式化表示:
\[
\min_{\phi}\mathbb{E}_{\mathcal{T}}[L_{\mathcal{T},B}(U_{\mathcal{T},A}(\phi))]
\]对于给定的任务\(\mathcal{T}\),内循环中使用训练样本\(A\) 进行优化,然后使用测试样本 \(B\) 计算得到损失,外循环使用损失对初始化参数求梯度,即可得出新任务上参数的优化方向。
\[
g_{MAML}=\frac{\partial L_{\mathcal{T},B}(U_{\mathcal{T},A}(\phi))}{\partial \phi}=L^{\prime}_{\mathcal{T},B}(\tilde{\phi})U^{\prime}_{\mathcal{T},A}(\phi), \qquad where \quad \tilde{\phi}=U_{\mathcal{T},A}(\phi))
\]其中 \(U^{\prime}_{\mathcal{T},A}(\phi)\) 可以视为是关于 \(U_{\mathcal{T},A}(\phi)\) 的雅可比矩阵,而 \(U_{\mathcal{T},A}(\phi)\) 可以视为是对初始化参数向量累加了一系列的梯度向量, \(U_{\mathcal{T},A}(\phi)=\phi+ g_1 + g_2 + \dots +g_k\) 。
FOMAML的简化:
将梯度向量视为常量,即可将雅可比矩阵转化为恒等操作(identity operation),所以可以简化外循环优化过程中所使用的梯度公式。
\[
g_{FOMAML}=L^{\prime}_{\mathcal{T},B}(\tilde{\phi})
\]具体流程如下:
- 采样任务\(\mathcal{T}\) ;
- 对初始化参数执行更新操作,得到\(\tilde{\phi}=U_{\mathcal{T},A}(\phi))\);
- 利用 \(\tilde{\phi}\) 计算对 \(\phi\) 的梯度,得到 \(g_{FOMAML}=L^{\prime}_{\mathcal{T},B}(\tilde{\phi})\)
- 将\(g_{FOMAML}\) 应用到外部循环优化中。
4.2 Reptile实现
算法描述
算法最后一步的模型参数更新的batch版本,可以写为如下形式:
\[
\phi \leftarrow \phi +\epsilon \frac{1}{n} \sum_{i=1}^{n}(\tilde{\phi_i}-\phi)
\]其中\(\tilde{\phi_i}=U^{k}_{\mathcal{T}_i} \left\{ \phi \right\}\) ,表示在第i个任务上对参数的更新操作。
这个算法与在损失期望上进行的联合训练十分相似。
当k=1时,算法对应于期望损失的随机梯度下降(SGD)。
\[
\begin{align}
g_{Reptile,k=1} & =\mathbb{E}_{\mathcal{T}}\mathrm{[\phi-U_{\mathcal{T}}(\phi)]/\alpha}\\
& =\mathbb{E}_{\mathcal{T}}\mathrm{[\nabla_{\phi}L_{\mathcal{T}}(\phi)]}
\end{align}
\]当k>1时,更新过程包含了\(L_{\mathcal{T}}\) 的二阶乃至更高阶的微分项。
4.3 理论分析
更新过程中的领头阶(Leading Order)展开
直觉是:
使用泰勒序列展开来近似表示Reptile与MAML的更新过程,发现两者具有相同的领头项(leading-order terms)——领头阶(第一项)起着最小化期望损失的作用;次领头项(第二项)及后续项最大化任务内的泛化性。
最大化同一任务中不同minibatch之间梯度的内积,对其中一个batch进行梯度更新会显著改善另一个batch的的表现。
表达式定义(\(i\in[1,k]\) 指代不同的batch)
\[
\begin{align}
&g_i=L^{\prime}_i(\phi_{i})\quad(在SGD过程中获得的梯度)\\
&\phi_{i+1}=\phi_i-\alpha g_i\quad(参数更新序列)\\
&\bar{g_i}=L^{\prime}_i(\phi_1)\quad (起始点梯度)\\
&\bar{H_i}=L^{\prime \prime}_i(\phi_1)\quad (起始点Hessian矩阵,即二阶梯度)
\end{align}
\]将SGD过程中获得的梯度,按照泰勒公式展开
近似表示MAML梯度(\(U_i\) 表示在第\(i\)个minibatch上对参数向量的更新操作)
领头阶展开
当k=2时,三者的一般表示形式为:
\[
\begin{align}
&g_{MAML}=\bar{g_2}-\alpha\bar{H_2}\bar{g_1}-\alpha\bar{H_1}\bar{g_2}+O(\alpha^2)\\
&g_{MAML}=g_2=\bar{g_2}-\alpha\bar{H_2}\bar{g_1}+O(\alpha^2)\\
&g_{Reptile}=g_1+g_2=\bar{g_1}+\bar{g_2}-\alpha\bar{H_2}\bar{g_1}+O(\alpha^2)\\
\end{align}
\]其中:
- 类似于\(\bar{g_1}\quad \bar{g_2}\)的项就是领头项,用于最小化联合训练损失;
- 类似于\(\bar{H_2}\bar{g_1}\)的项就是次领头项,作用是最大化不同批次数据上得到的梯度的内积。
在进行minibatch采样,取三种梯度的期望时,上述两种领头项分别用AvgGrad和AvgGradInner表示(k=2):
三种算法梯度的期望表示形式可以化为:
扩展到k>2的情况有:
- 可以看到三者AvgGradInner与AvgGrad之间的系数比的关系是:MAML > FOMAML > Retile。
- 这个比例与步长\(\alpha\),迭代次数\(k\) 正相关。
找到一个接近所有解流形(Solution Manifolds)的点
直觉:
Reptile收敛于一个解,这个解在欧式空间上与每个任务的最优解的流形接近。
用 \(\phi\) 表示网络初始化,\(\mathcal{W_{T}}\) 表示任务\(\mathcal{T}\)上的最优参数集。优化过程的最终目标是找到一个\(\phi\)使得其与所有任务的\(\mathcal{W_{T}}\) 之间的距离最小。
\[
\min_{\phi}\mathbb{E}_{\mathcal{T}}[\frac{1}{2} D(\phi,\mathcal{W_T})^2]
\]对参数\(\phi\)的梯度为:
在Reptile中每一次迭代相当于采样一个任务然后在上面执行一侧SGD更新。
实际情况下,很难直接计算出\(P_{\mathcal{W_T}}(\phi)\),即使得\(L_T\) 取得最小值的p。因此在Reptile中,用初始化参数\(\phi\)在\(L_T\) 上执行k步梯度下降后得到的结果来代替最优化参数\(\mathcal{W^{\star}_{T}(\phi)}\)。
五、实验
5.1 少样本分类
Few-Shot Classification(少样本分类)是少样本学习中的一类任务,在这类任务中,存在一个元数据集(Meta-Data Set),包含了许多类的数据,每类数据由若干个样本组成,这种任务的训练通常与K-Shot N-way分类任务绑定在一起,具体理解参见《关于N-Way K-Shot 分类问题的理解》。
建立与MAML一样的CNN训练模型,在Ominglot和MiniImageNet数据集上进行训练与测试,实验结果如下:
从两个表格中的数据可以看出,MAML与Reptile在加入了转导(Transduction)后,在Mini-ImageNet上进行实验,Reptile的表现要更好一些,而Omniglot数据集上正好相反。
5.2 不同的内循环梯度组合比较
通过在内循环中使用四个不重合的Mini-Batch,产生梯度数据\(g_1,g_2,g_3,g_4\) ,然后将它们以不同的方式进行线性组合(等价于执行多次梯度更新)用于外部循环的更新,进而比较它们之间的性能表现,实验结果如下图:
从曲线可以看出:
- 仅使用一个批次的数据产生的梯度的效果并不显著,因为相当于让模型用见到过的少量的数据去优化所有任务。
- 进行了两步更新的Reptile(绿线)的效果要明显不如进行了两步更新的FOMAML(红线),因为Reptile在AvgGradInner上的权重要小于FOMAML。
- 随着mini-batch数量的增多,所有算法的性能也在提升。通过同时利用多步的梯度更新,Reptile的表现要比仅使用最后一步梯度更新的FOMAML的表现好。
5.3 内循环中Mini-Batch 重合比较
Reptile和FOMAML在内循环过程中都是使用的SGD进行的优化,在这个优化过程中任何微小的变化都将导致最终模型性能的巨大变化,因此这部分的实验主要是探究两者对于内循环中的超数的敏感性,同时也验证了FOMAML在minibatch以错误的方式选取时会出现显著的性能下降情况。
mini-batch的选择有两种方式:
- shared-tail(共尾):最后一个内循环的数据来自以前内循环批次的数据
- separate-tail(分尾):最后一个内循环的数据与以前内循环批次的数据不同
采用不同的mini-batch选取方式在FOMAML上进行实验,发现随着内循环迭代次数的增多,采用分尾方式的FOMAML模型的测试准确率要高一些,因为在这种情况下,测试的数据选取方式与训练过程中的数据选取方式更为接近。
当采用不同的批次大小时,采用共尾方式选取数据的FOMAML的准确性会随着批次大小的增加而显著减小。当采用full-batch时,共尾FOMAML的表现会随着外循环步长的加大而变差。
共尾FOMAML的表现如此敏感的原因可能是最初的几次SGD更新让模型达到了局部最优,以后的梯度更新就会使参数在这个局部最优附近波动。
六、总结
Reptile有效的原因有二:
- 通过用泰勒级数近似表示更新过程,发现SGD自动给出了与MAML计算的二阶项相同的项。这一项调整初始权重,以最大限度地增加同一任务中不同小批量梯度之间的点积,从而增大模型的泛化能力。
- Reptile通过利用多次梯度更新,找到了一个接近所有最优解流形的点。
当执行SGD更新时,MAML形式的更新过程就已经被自动包含在其中了,通过最大化模型在不同批次数据之间的泛化能力,从而使得模型在微调(fine-tune)时能取得显著的效果。
【笔记】Reptile-一阶元学习算法的更多相关文章
- 【笔记】MAML-模型无关元学习算法
目录 论文信息: Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networ ...
- Factorization Machines 学习笔记(四)学习算法
近期学习了一种叫做 Factorization Machines(简称 FM)的算法.它可对随意的实值向量进行预測.其主要长处包含: 1) 可用于高度稀疏数据场景:2) 具有线性的计算复杂度.本文 ...
- Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(用于深度网络快速适应的元学习)
摘要:我们提出了一种不依赖模型的元学习算法,它与任何梯度下降训练的模型兼容,适用于各种不同的学习问题,包括分类.回归和强化学习.元学习的目标是在各种学习任务上训练一个模型,这样它只需要少量的训练样本就 ...
- 再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT)
再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Bluestein算法+分治FFT+FFT的优化+任意模数NTT) 目录 再探快速傅里叶变换(FFT)学习笔记(其三)(循环卷积的Blueste ...
- [置顶] 生成学习算法、高斯判别分析、朴素贝叶斯、Laplace平滑——斯坦福ML公开课笔记5
转载请注明:http://blog.csdn.net/xinzhangyanxiang/article/details/9285001 该系列笔记1-5pdf下载请猛击这里. 本篇博客为斯坦福ML公开 ...
- OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波
http://blog.csdn.net/chenyusiyuan/article/details/8710462 OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 201 ...
- 机器学习实战(Machine Learning in Action)学习笔记————06.k-均值聚类算法(kMeans)学习笔记
机器学习实战(Machine Learning in Action)学习笔记————06.k-均值聚类算法(kMeans)学习笔记 关键字:k-均值.kMeans.聚类.非监督学习作者:米仓山下时间: ...
- CS229笔记:生成学习算法
在线性回归.逻辑回归.softmax回归中,学习的结果是\(p(y|x;\theta)\),也就是给定\(x\)的条件下,\(y\)的条件概率分布,给定一个新的输入\(x\),我们求出不同输出的概率, ...
- Alink漫谈(十二) :在线学习算法FTRL 之 整体设计
Alink漫谈(十二) :在线学习算法FTRL 之 整体设计 目录 Alink漫谈(十二) :在线学习算法FTRL 之 整体设计 0x00 摘要 0x01概念 1.1 逻辑回归 1.1.1 推导过程 ...
随机推荐
- 抓包 抓nodejs的包 抓浏览器的包 抓手机的包
应用场景: 确认接口是能用的,但自己使用时就是不行,参数有没有传正确?格式对不对?傻傻分不清. 抓包工具:这里演示 charles , 常用的还有 Fiddler, HttpWatch, WireSh ...
- Nginx修改时间戳
1.安装nginx,注意不要安装nginx-common或者nginx-full sudo apt-get install nginx sudo apt-get install nginx-commo ...
- VUE 同一页面路由参数变化,视图不刷新的解决方案
1.监听路由处理 watch: { $route(to, from) { // 逻辑 // 重新调用数据接口 } }, 2.beforeRouteUpdate导航守卫 路由更新时触发 beforeRo ...
- C# .NET “公钥证书” (.cer .pem)转换为 RSACryptoServiceProvider 对象。导出“公钥”
“公钥证书” .cer 文件是直接可以用X509Certificate2 对象来读取的,但 .cer 文件 不便于存储. “公钥证书” .pem 文件内容如下: -----BEGIN CERTIFIC ...
- tp中model加载机制
$user_model = D('User'); 如果当前模块下面有UserModel,就优先使用当前模块下的UserModel.如果当前模块下没有UserModel,就回去Common模块下找Use ...
- JAVA 扫描指定路径下所有的jar包,并保存所有实现固定接口的类型
private static Map<String, Object> loadAllJarFromAbsolute(String directoryPath) throws NoSuchM ...
- graph处理工具
仅作为记录笔记,完善中...................... 1 PyGSP https://pygsp.readthedocs.io/en/stable/index.html ht ...
- Idea Spring 、SpringBoot相关设置技巧
1.Spring变量依赖注入出现红色波浪线 Could not autowire. No beans of 'UserMapper' type found. less... (Ctrl+F1) Che ...
- 0-python变量及基本数据类型
目录 1.变量2.字符串3.布尔类型4.整数5.浮点数6.日期 1.变量 1.1.变量的定义 - 类似于标签 1.2.变量的命名规则 - (强制)变量名只能包含数字.字母.下划线 - (强制)不能以数 ...
- Python爬虫:现学现用xpath爬取豆瓣音乐
爬虫的抓取方式有好几种,正则表达式,Lxml(xpath)与BeautifulSoup,我在网上查了一下资料,了解到三者之间的使用难度与性能 三种爬虫方式的对比. 这样一比较我我选择了Lxml(xpa ...