【笔记】MAML-模型无关元学习算法
目录
论文信息:
Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017: 1126-1135.
一、摘要
元学习的目标是在各种学习任务上训练一个模型,这样它就可以使用少量的训练样本来解决新的学习任务。
- 本文提出了一种与模型无关的元学习算法,它适用于任何基于梯度下降进行训练的模型,并且适用于各种学习问题,如分类(Classification)、回归(Regression)和强化学习(Reinforcement Learning)。
- 在本文提出的方法中,模型的参数被显式地训练,模型在处理新任务时,只需几次的梯度更新以及少量的训练数据就能取得良好的泛化性能。
该方法在两种few-shot图像分类基准(Omniglot和 MiniImagenet)上取得了较好的性能,在few-shot回归上取得了较好的效果,并利用神经网络策略加速了策略梯度强化学习的微调。
二、背景
显式训练与隐式训练
参考显函数与隐函数的概念:
- 隐函数:能确定y与x之间关系的方程,F(x,y)=0。x与y混杂在一起。有些隐函数可显化为显函数。
- 显函数:用y=f(x)表示的函数。x与y明显区分。
- 函数是方程,方程不一定是函数。因为函数需要实现一个数域到另一个数域的映射,而方程只要是含有未知数的等式即可。
这样模型参数的显式训练与隐式训练就可以理解为因果区分与因果混杂的情况。
- 隐式训练:没有明确的表达式来对目标参数进行更新。
- 显式训练:存在明确的表达式来更新目标参数。
参数方法与非参数方法
- 参数方法(parametric method):根据先验知识假定模型服从某种分布,然后利用训练集估计出模型参数。这种方法中模型的参数固定,不随数据点的变化而变化。
- 非参数方法(parametric method):基于记忆训练集,在预测新样本值时每次都会重新训练数据,得到新的参数值。参数的数目随数据点的变化而变化。
Hessian Matrix(海森矩阵)
海塞矩阵(Hessian Matrix),又译作海森矩阵,是一个多元函数的
二阶偏导数
构成的方阵。处理一元函数极值问题,如\(f(x)=x^2\) ,我们会先求一阶导数,即 \(f^{\prime}(x)=2x\) ,然后根据
费马定理
——极值点处的一阶导数一定等于 0。但这仅是一个必要条件,而非充分条件。如 \(f(x)=x^3\),显然只检查一阶导数是不足以下定论的。所以进行二次求导,得出以下规律:- 如果一阶导数\(f^{\prime}(x)=0\) 且二阶导数\(f^{\prime \prime}(x_0)<0\) ,则\(f(x)\) 在此点处取得局部极大值;
- 如果一阶导数\(f^{\prime}(x)=0\) 且二阶导数\(f^{\prime \prime}(x_0)>0\) ,则\(f(x)\) 在此点处取得局部极小值;
- 如果一阶导数\(f^{\prime}(x)=0\) 且二阶导数\(f^{\prime \prime}(x_0)=0\) ,则无法确定
处理多元函数极值问题,则需要首先对每个变量求偏导,令其为零,定位极值点的可能位置,然后利用二阶导数判断是极大值还是极小值。\(n\) 元函数有 \(n^2\) 个二阶导数,因此构成海森矩阵:
\[
\mathbf{H}=\begin{bmatrix}
\frac{\partial^2f}{\partial x_1^2} & \frac{\partial^2f}{\partial x_1\partial x_2} & \cdots &\frac{\partial^2f}{\partial x_1\partial x_n} \\
\frac{\partial^2f}{\partial x_2\partial x_1} & \frac{\partial^2f}{\partial x_2^2} & \cdots &\frac{\partial^2f}{\partial x_2\partial x_n} \\ \vdots & \vdots & \ddots & \vdots \\
\frac{\partial^2f}{\partial x_n\partial x_1}&\frac{\partial^2f}{\partial x_n\partial x_2}&\cdots &\frac{\partial^2f}{\partial x_n^2} \end{bmatrix}
\]- 海森矩阵的极值判断阶段如下:
- 如果是正定矩阵,则临界点处是一个局部极小值
- 如果是负定矩阵,则临界点处是一个局部极大值
- 如果是不定矩阵,则临界点处不是极值
- 海森矩阵的极值判断阶段如下:
元学习问题引入
元学习过程实际上是一个创造一个
高级代理
的过程,这个代理在处理新任务的新数据时,能将先验知识整合进来并且能避免过拟合,即在不同任务之间具备泛化能力。高级代理可以理解为创造模型的模型,或者是一组模型参数,它能够根据不同的任务生成不同的模型参数,这套模型参数能够在新任务给定的新数据上快速的学习,适应任务的需要。
为了得到具有快速适应能力的模型,元学习训练一般以
Few-Shot Learning(少样本学习)
的形式进行。Few-Shot,可以分为1~k shot,即在训练过程中提供给模型1~k个样本数据,让模型进行学习。
注意与
Small Sample Learning(SSL,小样本学习)
进行区分。后者的范围比前者更加广泛,具体参见Small Sample Learning in Big Data Era
通过少量的样本数据构建成一个任务,然后让元学习模型在许多依此法创建的任务上进行训练学习,这样,经过训练的元学习模型就能凭借少量的数据和几次的训练快速适应新的任务了。
实际上,元学习模型的训练过程是以一整个一整个的任务作为”训练数据样本“的。
元学习问题的公式化表达
概念定义
定义一个模型,用\(f\)表示。模型\(f\)能实现观察值\(x\)到输出值\(a\)的映射。
定义单个任务\(T\):
\[
\mathcal{T= \left\{ L(\mathrm{x_1,a_1,\dots,x_H,a_H}),q(\mathrm{x_1}),q(\mathrm{x_{t+1}|x_t,a_t}),\mathrm{H}\right\}}
\]- \(\mathcal{L}\)表示损失函数,\(\mathcal{L(\mathrm{x_1,a_1,\dots,x_H,a_H})}\rightarrow \mathbb{R}\)。
- \(\mathcal{q}(\mathrm{x_1})\)表示初始观测变量的分布。
- \(\mathcal{q(\mathrm{x_{t+1}|x_t,a_t})}\)表示转移分布。
- \(\mathrm{H}\)表示跨度(Episode Length),对于i.i.d(独立同分布)监督学习问题,H=1。
期望模型适应的任务的分布\(p(\mathcal{T})\)。
学习过程
- 初始化:随机初始化元学习模型参数\(\theta\),各子任务模型的初始化参数是对\(\theta\)的拷贝。
元训练
- 从\(p(\mathcal{T})\)中抽取任务\(\mathcal{T_i}\);
- 从\(\mathcal{q(i)}\)中抽取\(\mathrm{K}\)个样本;
- 用这\(\mathrm{K}\)个样本对任务\(\mathcal{T_i}\)进行训练,得到相应的损失\(\mathcal{L_{T_i}}\),并对该任务的模型参数进行梯度更新;
- 在新的数据样本上测试更新后的网络,得到错误情况。
元测试
- 根据各个任务更新后的网络的表现(test error)求初始化参数的梯度,并对元学习模型的参数其进行更新;
- 测试其在元测试集任务上的表现,即为元学习模型的最终表现。
三、介绍
本文提出的MAML算法的关键思想:训练模型的初始化参数,使模型能在来自新任务的少量数据上对参数执行数次(1~多次)的梯度更新后能得到最佳的表现。
从特征学习的角度理解——MAML算法试图建立一种模型的内部表示,这种内部表示广泛适用于许多任务。这样在处理新的任务时,只需对模型参数进行简单的微调就能产生较好的结果。
从动态系统的角度理解——MAML的学习过程就是要让新任务的损失函数对于参数的敏感度最大化。当具有较高的敏感度时,参数的微小的局部变化就可以导致任务损失的巨大提升。
动态系统:若系统在t0时刻的响应y(t0),不仅与t0时刻作用于系统的激励有关,而且与区间(-∞,t0)内作用于系统的激励有关,这样的系统称为动态系统。
本文的主要贡献包括以下几个方面:
- 提出了一种元学习的简单模型以及与任务无关的算法,通过训练模型参数,使得模型参数只要经过少量次数的梯度更新就能实现在新任务上的快速学习。
- 在不同的模型,如全连接和卷积网络,以及不同领域上,如少样本回归、图片分类和增强学习上验证了本文提出的算法。
- 本文提出的方法通过使用少量参数,能够与目前最先进的专门用于监督分类的one-shot 学习算法媲美,并且能够应用于回归任务和加速任务可变情况下的强化学习过程。
四、实现
MAML算法的实现直觉(Intuition)是模型的某些内部表示更容易在不同的任务之间转换。比如存在某种内部表示能够适用于任务分布\(\mathcal{p(T)}\)中的所有任务而不是某一个具体的任务。由于最终模型会在新任务上使用基于梯度下降的学习规则进行微调,所以可以以一种
显式的方式
去学习一个具备这种规则的模型。这种待学习的规则可以理解为一组对任务变化敏感的模型参数,当参数沿着任务的损失梯度方向变化时,可以使得任务损失得到较大的改善。
原理图如下:
- \(\theta\) 是已经优化过的模型参数表示。
- 当 \(\theta\) 沿着新任务损失梯度方向变化时,会使得任务损失大幅改善,从而得到对于新任务的最佳模型参数 \(\theta^{\star}\)
算法描述:
模型由函数 \(f_{\theta}\) 表示,该函数由参数 \(\theta\) 决定。
整个算法分为两个循环:
- 两者共享模型参数 \(\theta\) 。
- 两者的梯度更新的学习率分别由超参数 \(\alpha\) 和 \(\beta\) 表示
- 内循环计算各子任务的损失 \(\mathcal{L_{T_{i}}}\) 和进行一至多次梯度更新后的参数 \(\theta^{'}_{i}\) ;
- 外循环根据内循环的优化参数在新任务上重新计算损失,并计算其对初始参数的梯度,然后对初始参数进行
SGD
梯度更新。 - 重复内外循环,就可以得到元学习模型对于任务分布$ \mathcal{p(T)}$的最佳参数。
注意
- 拥有“最佳参数”的模型在处理新任务时,由于具备了先验知识,所以只需进行微调就能产生较好的效果。
- 外循环又称之为
元优化(meta-optimization)
。 - 为了适应不同的任务,内循环中的模型参数会演化成 \(\theta^{\prime}\)。而外循环中模型参数需要等到内循环中的所有任务的模型参数都优化后再进行更新。
- 由于存在一个嵌套关系,外层的梯度更新依赖内层的梯度,因此就会出现二阶导数(梯度的梯度)的计算,需要使用到
海森向量积(Hessian-Vector Product)
。 - 在论文中,作者提出了一种近似算法,利用一阶梯度近似代替二阶梯度,形成FOMAML(First-Order MAML)算法,具体公式推导过程,见MAML讲解-李弘毅。
算法扩展:
监督学习(Supervised Learning):算法中的公式(2)和公式(3)分别指代下面的两个损失函数。
分类(Classification)任务的损失函数采用交叉熵(cross entropy):
\[
\mathcal{L_{T_i}(f_\phi)}=\sum_{x^{(j)},y^{(j)}\sim \mathcal{T_i}}y^{(j)}\log f_{\phi}(x^{(j)})+(1-y^{(j)})\log(1-f_{\phi}(x^{(j)}))
\]回归(Regression)任务的损失函数采用均方差(mean-squared error):
\[
\mathcal{L_{T_i}(f_\phi)}=\sum_{x^{(j)},y^{(j)}\sim \mathcal{T_i}}\begin{Vmatrix} f_{\phi}(x^{(j)})-y^{(j)}\end{Vmatrix}_2^2
\]
强化学习(Reinforcement Learning):算法中的公式(4)指代下面的损失函数。
强化学习损失函数
- 强化学习过程基于马尔可夫决策过程(Markov Decision Porcess)。
- 具体细节还未深入了解,待补充……
\[
\mathcal{L_{T_i}(f_\phi)}=-\mathbb{E}_\mathcal{x_t,a_t\sim f_\phi,q_{T_i}}[\sum_{t=1}^H R_i(x_t,a_t)]
\]
五、实验
实验代码:
- PyTorch版:https://github.com/dragen1860/MAML-Pytorch
- TensorFlow版:https://github.com/cbfinn/maml
回归(正弦曲线)
- 通过将MAML算法模型与预训练模型比较,分别提供K=5和K=10个样本数据,进行回归拟合。可以看到:
- 在没有提供任何数据点的情况下,MAML由于已经学习到了正弦波的周期结构,所以能够对曲线进行一定的评估;
- 对于预训练模型,由于输出与已学习到的先验知识冲突,导致模型无法找到一个合适的表示形式,从而无法通过少量的样本进行拟合推断。
比较MAML和预训练模型的学习曲线可以得出:
- MAML算法通过少量次数的梯度更新就能实现较低的错误率,没有对少量的数据点过拟合,达到收敛。
- 预训练模型则由于缺乏泛化能力,对与少量数据点,很容易过拟合。
- 通过将MAML算法模型与预训练模型比较,分别提供K=5和K=10个样本数据,进行回归拟合。可以看到:
分类
通过将MAML模型以及简化后的FOMAML模型与用于Few-Shot Learning 分类的主流模型在Omiglot和MiniImagenet数据集上比较,可以发现:
MAML无视数据集差异、数据点多少以及网络结构差异,都有优异的表现。
FOMAML模型的表现与MAML的表现非常接近,但是两者的计算消耗却不同,FOMAML的计算复杂度要明显低于MAML,这一点也是值得进一步研究的问题。
对此,作者推测在大多数情况下,损失函数的二阶导数非常接近零,因而对模型表现没有产生太大的影响。
在On First-Order Meta-Learning Algorithms一文中,作者用泰勒公式,对导数进行了展开分析,揭露了深层次的原因。
强化学习
六、总结
- 提出了一种不引入任何学习参数(实际上增加了学习率\(\alpha 和 \beta\))的通过梯度下降学习模型参数的元学习方法。
- MAML可以以与任何适合于基于梯度的训练的模型表示,以及任何可微分的目标(包括分类、回归和强化学习)相结合。
- MAML只产生一个权值初始化,所以可以使用任意数量的数据和任意数量的梯度步长来执行自适应。
- MAML可以使用策略梯度和非常有限的经验来适应RL代理。
- 重用来自过去任务的知识可能是构建高容量可伸缩模型(如深度神经网络)的一个关键因素,该模型能够使用小数据集进行快速训练。
- 这种元学习技术可以应用于任何问题和任何模型,可以使多任务初始化成为深度学习和强化学习的标准组成部分。
【笔记】MAML-模型无关元学习算法的更多相关文章
- 【笔记】Reptile-一阶元学习算法
目录 论文信息 Nichol A , Achiam J , Schulman J . On First-Order Meta-Learning Algorithms[J]. 2018. 一.摘要 本文 ...
- Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(用于深度网络快速适应的元学习)
摘要:我们提出了一种不依赖模型的元学习算法,它与任何梯度下降训练的模型兼容,适用于各种不同的学习问题,包括分类.回归和强化学习.元学习的目标是在各种学习任务上训练一个模型,这样它只需要少量的训练样本就 ...
- 强化学习(十七) 基于模型的强化学习与Dyna算法框架
在前面我们讨论了基于价值的强化学习(Value Based RL)和基于策略的强化学习模型(Policy Based RL),本篇我们讨论最后一种强化学习流派,基于模型的强化学习(Model Base ...
- 人工智能中小样本问题相关的系列模型演变及学习笔记(二):生成对抗网络 GAN
[说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![握手][握手] [再啰嗦一下]本文衔接上一个随笔:人工智能中小样本问题相关的系列模型演变及学习 ...
- 机器学习实战 - 读书笔记(07) - 利用AdaBoost元算法提高分类性能
前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习笔记,这次是第7章 - 利用AdaBoost元算法提高分类性能. 核心思想 在使用某个特定的算法是, ...
- CS229笔记:生成学习算法
在线性回归.逻辑回归.softmax回归中,学习的结果是\(p(y|x;\theta)\),也就是给定\(x\)的条件下,\(y\)的条件概率分布,给定一个新的输入\(x\),我们求出不同输出的概率, ...
- 伯克利、OpenAI等提出基于模型的元策略优化强化学习
基于模型的强化学习方法数据效率高,前景可观.本文提出了一种基于模型的元策略强化学习方法,实践证明,该方法比以前基于模型的方法更能够应对模型缺陷,还能取得与无模型方法相近的性能. 引言 强化学习领域近期 ...
- 【HLSL学习笔记】WPF Shader Effect Library算法解读之[BandedSwirl]
原文:[HLSL学习笔记]WPF Shader Effect Library算法解读之[BandedSwirl] 因工作原因,需要在Silverlight中使用Pixel Shader技术,这对于我来 ...
- Factorization Machines 学习笔记(四)学习算法
近期学习了一种叫做 Factorization Machines(简称 FM)的算法.它可对随意的实值向量进行预測.其主要长处包含: 1) 可用于高度稀疏数据场景:2) 具有线性的计算复杂度.本文 ...
随机推荐
- Qt编写数据可视化大屏界面电子看板系统
一.前言 目前大屏大数据可视化UI这块非常火,趁热也用Qt来实现一个,Qt这个一站式超大型GUI超市,没有什么他做不了的,大屏电子看板当然也不在话下,有了QSS和QPainter这两个无敌的工具组合, ...
- k8s记录-k8s基本概念和术语
每次个节点上当然都要运行Docker.Docker来负责所有具体的映像下载和容器运行. Kubernetes主要由以下几个核心组件组成: etcd保存了整个集群的状态: apiserver提供了资源操 ...
- EasyNVR摄像机网页无插件直播方案H5前端构建之:如何播放HLS(m3u8)直播流
背景描述 HLS (HTTP Live Streaming)是Apple的动态码率自适应技术,主要用于PC和Apple终端的音视频服务,包括一个m3u(8)的索引文件,TS媒体分片文件和key加密串文 ...
- NuxtJS实战,一个博客系统
前言 这个项目诞生于17年5月,距今已有两年多了,在这两年期间经历了很多变更,从简单到复杂,又从复杂到简单,并且以后一直会保持这种简单状态.最近迎来了一次更新,因此特意分享一下.虽然只有我一个人使用( ...
- 在日志中记录Java异常信息的正确姿势
遇到的问题 今天遇到一个线上的BUG,在执行表单提交时失败,但是从程序日志中看不到任何异常信息. 在Review源代码时发现,当catch到异常时只是输出了e.getMessage(),如下所示: l ...
- LeetCode 912. 排序数组(Sort an Array) 43
912. 排序数组 912. Sort an Array 题目描述 每日一算法2019/6/15Day 43LeetCode912. Sort an Array
- vue中ref在input中详解
当我们在项目中遇见文本输入框的时候,获取时刻输入框中的值 1.v-model <template> <input type="text" v-model=&quo ...
- SQL Server 2019 中标量用户定义函数性能的改进
在SQL Server中,我们通常使用用户定义的函数来编写SQL查询.UDF接受参数并将结果作为输出返回.我们可以在编程代码中使用这些UDF,并且可以快速编写查询.我们可以独立于任何其他编程代码来修改 ...
- [转帖]征服诱人的Vagrant!
征服诱人的Vagrant! https://www.cnblogs.com/hafiz/ 一.背景 最近要开始深入学习分布式相关的东西了,那第一步就是在自己的电脑上安装虚拟机,以前在Windows ...
- 递归实现全排列python
python递归实现"abcd"字符串全排列 1.保持a不动,动bcd 2.保持b不动,动cd 3.保持c不动,动d def pailie(head="",st ...