摘要:我们提出了一种不依赖模型的元学习算法,它与任何梯度下降训练的模型兼容,适用于各种不同的学习问题,包括分类、回归和强化学习。元学习的目标是在各种学习任务上训练一个模型,这样它只需要少量的训练样本就可以解决新的学习任务。在我们的方法中,模型的参数被显式地训练,使得少量的梯度步骤和少量的来自新任务的训练数据能够在该任务上产生良好的泛化性能。实际上,我们的方法训练模型易于微调。结果表明,该方法在两个few shot图像分类基准上都取得了最新的性能,在少镜头回归上取得了良好的效果,并加速了基于神经网络策略的策略梯度强化学习的微调。

1.简介

快速学习是人类智能的一个标志,无论是从几个例子中识别出物体,还是在几分钟的经验之后快速学习新技能。我们的人工智能体也应该能够做到这一点,只需从几个例子中快速学习和适应,并随着更多数据的可用而继续适应。这种快速和灵活的学习是具有挑战性的,因为代理必须集成其先前的经验与少量的新信息,同时避免过拟合到新的数据。此外,以往经验和新数据的形式将取决于任务。因此,为了获得最大的适用性,学习到学习(或元学习)的机制应该是任务和完成任务所需的计算形式。

在这项工作中,我们提出了一个通用的和模型不可知的元学习算法,因为它可以直接应用于任何学习问题和模型的梯度下降过程训练。我们的重点是深入的神经网络模型,但是我们说明了我们的方法如何能够轻松地处理不同的体系结构和不同的问题设置,包括分类、回归和策略梯度强化学习,只需最小的修改。在元学习中,训练模型的目标是从少量的新数据中快速学习一个新的任务,元学习者通过训练模型能够学习大量不同的任务。我们的方法的关键思想是训练模型的初始参数,使得模型在一个新任务上具有最大的性能,在参数已经通过一个或多个梯度步骤更新之后,该梯度步骤由来自该新任务的少量数据计算。与以往学习更新函数或学习规则的元学习方法(Schmidhuber,1987;Bengio等人,1992;Andrychowicz等人,2016;Ravi&Larochelle,2017)不同,我们的算法不会扩展学习参数的数量,也不会对模型架构设置约束(例如,通过要求递归模型(Santoro等人,2016)或暹罗网络(Koch,2015)),它可以很容易地与全连接、卷积或递归神经网络相结合。它还可以用于多种损失函数,包括可微监督损失和不可微强化学习目标。

从特征学习的角度来看,训练一个模型的参数,使几个梯度步骤,甚至一个梯度步骤,能够在一个新任务上产生良好的结果的过程,可以看作是建立了一个广泛适用于许多任务的内部表示。如果内部表示适用于许多任务,只需稍微微调参数(例如,在前馈模型中主要修改顶层权重)就可以产生良好的效果。实际上,我们的过程优化了易于快速微调的模型,允许在正确的空间进行快速学习。从动态系统的观点来看,我们的学习过程可以被看作是最大化新的任务的损失函数相对于参数的灵敏度:当灵敏度高时,对参数的小的局部变化可以导致任务损失的大幅度改善。

这项工作的主要贡献是一个简单的模型和任务无关的元学习算法,它训练模型的参数,使得少量的梯度更新将导致对新任务的快速学习。我们在不同的模型类型上,包括完全连接和卷积网络,以及在几个不同的领域,包括少量镜头回归、图像分类和强化学习,演示了该算法。我们的评估表明,我们的元学习算法在使用较少参数的情况下,与专门为监督分类设计的最先进的一次性学习方法相比,具有优势,但它也可以很容易地应用于回归,并且可以在任务变率存在的情况下加速强化学习,实质上优于初始化时的直接预训练。

2.与模型无关的元学习

我们的目标是训练能够实现快速适应的模型,这种问题设置通常被形式化为少量的镜头学习。在这一节中,我们将定义问题设置并给出我们算法的一般形式。

2.1 元学习问题设置

少镜头元学习的目标是训练一个模型,该模型仅使用少量数据点和训练迭代就能快速适应新任务。为了实现这一点,模型或学习者在元学习阶段接受一组任务的训练,这样训练后的模型可以快速适应新的任务,只需使用少量的例子或试验。实际上,元学习问题把整个任务当作训练的例子。在这一节中,我们以一种通用的方式将这个元学习问题设置形式化,包括不同学习领域的简短示例。我们将在第3节详细讨论两个不同的学习领域。

我们考虑一个表示为f的模型,它将观测值x映射到输出a。在元学习过程中,该模型被训练成能够适应大量或无限数量的任务。由于我们希望将我们的框架应用于各种学习问题,从分类到强化学习,我们在下面介绍了学习任务的一般概念。形式上,每个任务T=fL(x1;a1;::;xH;a H);q(x1);q(xt+1jxt;at);Hg由损失函数L、初始观测值的分布q(x1)、过渡分布q(xt+1jxt;at)和在i.i.d.监督学习问题中的事件长度H组成,长度H=1。模型可以通过在每次t选择一个输出来生成长度H的样本。损失L(x1;a1;::;xH;aH)!R,提供特定于任务的反馈,其形式可能是错误分类损失或马尔可夫决策过程中的成本函数。

图1。我们的模型不可知元学习算法(MAML)的图表,该算法优化表示θ,以快速适应新任务。

在我们的元学习场景中,我们考虑任务p(T)的分布,我们希望我们的模型能够适应。在K-shot学习环境中,该模型仅从qi抽取的K个样本和Ti生成的反馈LTi中学习从p(T)抽取的新任务Ti。在元训练过程中,从p(T)中抽取一个任务Ti,用K个样本对模型进行训练,并从Ti的相应损失LTi中得到反馈,然后在Ti的新样本上进行测试。然后,通过考虑来自qi的新数据的测试误差相对于参数的变化,改进了f模型。实际上,抽样任务Ti上的测试错误是元学习过程的训练错误。在元训练结束时,从p(T)中抽取新任务,在K个样本中学习后用模型的性能来衡量元性能。一般来说,元测试的任务是在元训练期间进行的。

2.2 一种与模型无关的元学习算法

与之前的研究不同,之前的研究试图训练摄取整个数据集的递归神经网络(Santoro等人,2016;Duan等人,2016b)或在测试时可与非参数方法结合的特征嵌入(Vinyals等人,2016;Koch,2015),我们提出了一种通过元学习来学习任何标准模型参数的方法,从而为模型的快速适应做好准备。这种方法背后的直觉是,一些内部表示比其他表示更容易传递。例如,神经网络可以学习广泛适用于p(T)中所有任务的内部特征,而不是单个任务。我们如何才能鼓励出现这样的通用表示?我们对这个问题采取了一种明确的方法:由于模型将使用基于梯度的学习规则对一个新任务进行微调,因此我们将以这样的方式学习一个模型:这种基于梯度的学习规则可以在从p(T)提取的新任务上快速进行,而不会过度拟合。实际上,我们的目标是找到对任务变化敏感的模型参数,使得参数的微小变化将对从p(T)中提取的任何任务的损失函数产生很大的改善,当沿着损失梯度的方向改变时(见图1)。我们对模型的形式没有任何假设,只是假设它是由某个参数向量θ参数化的,并且θ中的损失函数足够光滑,我们可以使用基于梯度的学习技术。

形式上,我们考虑一个参数化函数fθ表示的模型。当适应新任务Ti时,模型的参数θ变为θi0。在我们的方法中,使用任务Ti上的一个或多个梯度下降更新来计算更新的参数向量θi0。例如,当使用一个渐变更新时,

步长α可以固定为超参数或金属化。为了表示简单,我们将在本节的其余部分考虑一个渐变更新,但是使用多个渐变更新是一个简单的扩展。

通过优化fθ0i相对于p(T)抽样任务θ的性能来训练模型参数。具体来说,元目标如下:

注意,元优化是对模型参数θ执行的,而目标是使用更新的模型参数θ0计算的。实际上,我们提出的方法旨在优化模型参数,使得在一个新任务上的一个或少量梯度步骤将在该任务上产生最大有效的行为。

通过随机梯度下降(SGD)进行跨任务的元优化,使得模型参数θ更新如下:

                             (1)

其中β是元步长。在一般情况下,算法1概述了完整的算法。

MAML元梯度更新包括通过梯度的梯度。在计算上,这需要一个额外的向后通过f来计算Hessian向量由TensorFlow等标准深度学习库支持的产品(Abadi等人,2016)。在我们的实验中,我们还包括一个比较来丢弃这个向后传递和使用一阶近似,我们在第5.2节中讨论。

3.MAML种类

在这一节中,我们将讨论我们的元学习算法在监督学习和强化学习中的具体实例。这两个领域在损失函数的形式和任务如何生成数据并呈现给模型方面有所不同,但在这两种情况下都可以应用相同的基本适应机制。

3.1监督回归与分类

在有监督任务领域,很少有镜头学习得到很好的研究,其目标是仅从该任务的几个输入/输出对中学习一个新函数,使用来自类似任务的先验数据进行元学习。例如,目标可能是在只看到Segway的一个或几个示例之后对Segway的图像进行分类,该模型以前见过许多其他类型的对象。同样,在少镜头回归中,目标是在对具有类似统计特性的多个函数进行训练后,仅从从该函数采样的少数数据点预测连续值函数的输出。

为了在第2.1节中的元学习定义的上下文中形式化有监督回归和分类问题,我们可以定义horizon H=1并在xt上删除timestep下标,因为模型接受单个输入并产生单个输出,而不是一系列输入和输出。任务Ti从qi生成K i.i.d.观测x,任务损失由模型的x输出与该观测和任务的相应目标值y之间的误差表示。

用于监督分类和回归的两个常用损失函数是交叉熵和均方误差(MSE),我们将在下面描述;不过,也可以使用其他监督损失函数。对于使用均方误差的回归任务,损失的形式为:

其中x(j);y(j)是从任务Ti采样的输入/输出对。在K-shot回归任务中,为每个任务提供K个输入/输出对用于学习。

同样,对于具有交叉熵损失的离散分类任务,损失的形式为:

               (3)

根据传统的术语,K-shot分类任务使用每个类的K个输入/输出对,对总共NK个数据点进行N向分类。给定任务p(Ti)上的分布,这些损失函数可以直接插入到第2.2节中的方程中以执行元学习,如算法2所述。

3.2强化学习

在强化学习(RL)中,少量元学习的目标是使代理能够使用少量的测试设置经验快速获取新测试任务的策略。一项新的任务可能涉及到实现一个新的目标或在一个新的环境中成功地完成一个先前训练过的目标。例如,一个代理可能学会快速找出如何导航迷宫,这样,当面对新迷宫时,它可以确定如何仅用少量样本可靠地到达出口。在这一节中,我们将讨论如何将MAML应用于RL的元学习。

每个RL任务Ti包含一个初始状态分布qi(x1)和一个转移分布qi(xt+1jxt;at),并且损失LTi对应于(负)报酬函数R。因此,整个任务是一个水平H的Markov决策过程(MDP),允许学习者查询有限数量的样本轨迹以进行少量射击学习。MDP的任何方面都可能在p(T)中跨任务发生变化。正在学习的模型fθ是一种策略,它在每个时间步t 2 f1;::;Hg时从状态xt映射到操作上的分布。任务Ti和模型fφ的损失形式如下

                         (4)

在K-shot强化学习中,fθ和任务Ti(x1;a1;::xH)的K卷展和相应的奖励R(xt;at)可用于适应新的任务Ti。

由于动态未知,期望报酬一般是不可微的,因此我们使用策略梯度方法来估计模型梯度更新和元优化的梯度。由于策略梯度是一种on-policy算法,因此在fθ的自适应过程中,每个额外的梯度步骤都需要从当前策略fθi0中获取新的样本。我们在算法3中详细说明了算法。该算法与算法2的结构相同,主要区别在于步骤5和步骤8需要从与任务Ti相对应的环境中采样轨迹。该方法的实际实现还可以使用最近为策略梯度算法提出的各种改进,包括状态或动作相关基线和信任区域(Schulman等人,2015)。

四.相关工作

本文提出的方法解决了元学习的一般问题(Thrun&Pratt,1998;Schmidhuber,1987;Naik&Mammone,1992),其中包括少量的镜头学习。元学习的一种流行方法是训练元学习者,学习如何更新学习者模型的参数(Bengio等人,1992年;Schmidhuber,1992年;Bengio等人,1990年)。该方法已应用于学习优化深层网络(Hochreiter等人,2001;Andrychowicz等人,2016;Li&Malik,2017),以及学习动态变化的递归网络(Ha等人,2017)。最近的一种方法学习了权重初始化和优化器,用于很少的镜头图像识别(Ravi&Larochelle,2017)。与这些方法不同,MAML学习者的权重是使用梯度而不是学习更新来更新的;我们的方法不引入额外的元学习参数,也不需要特定的学习者架构。

对于生成性建模(Edwards&Storkey,2017;Rezende et al.,2016)和图像识别(Vinyals et al.,2016)等特定任务,也很少开发镜头学习方法。少镜头分类的一个成功方法是学习使用暹罗网络(Koch,2015)在学习的度量空间中比较新的例子,或使用注意机制进行重复(Vinyals等人,2016;Shyam等人,2017;Snell等人,2017)。这些方法产生了一些最成功的结果,但很难直接推广到其他问题,如强化学习。相反,我们的方法对模型的形式和特定的学习任务是不可知的。

元学习的另一种方法是在许多任务上训练记忆模型,在这些任务中,反复学习者被训练以适应新任务的展开。此类网络已应用于少数镜头图像识别(Santoro等人,2016;Munkhdalai&Yu,2017)和学习“快速”强化学习代理(Duan等人,2016b;Wang等人,2016)。实验表明,该方法在fewshot分类上优于递归方法。此外,与这些方法不同,我们的方法只是提供了一个很好的权重初始化,并对学习者和元更新使用相同的梯度下降更新。因此,很容易对学习者进行微调以获得额外的梯度步骤。

我们的方法也与深度网络的初始化方法有关。在计算机视觉中,经过大规模图像分类预训练的模型已经被证明能够学习一系列问题的有效特征(Donahue等人,2014)。相比之下,我们的方法显式地优化了模型以实现快速的适应性,只需几个例子就可以让它适应新的任务。我们的方法也可以被视为显式最大化新任务损失对模型参数的敏感性。许多先前的研究已经探索了深度网络的敏感性,通常是在初始化的背景下(Saxe等人,2014年;Kirkpatrick等人,2016年)。大多数的这些工作都考虑了良好的随机初始化,尽管一些论文已经讨论了依赖于数据的初始化(Krahenb¨uhl et al。¨,2016;Salimans&Kingma,2016),包括学习的初始化(Husken&Goerick,2000;Maclaurin等人,2015)。相比之下,我们的方法明确地训练了给定任务分布的敏感度参数,允许在一个或几个梯度步骤中对诸如K-shot学习和快速强化学习等问题进行非常有效的适应。

5.实验评价

我们的实验评估的目标是回答以下问题:(1)MAML能快速学习新任务吗?(2)MAML可以用于多个不同领域的元学习,包括监督回归、分类和强化学习吗?(3)使用MAML学习的模型能否通过附加的梯度更新和/或示例继续改进?

我们认为所有的元学习问题都需要在测试时对新任务进行一定程度的适应。如果可能,我们将结果与oracle进行比较,oracle接收任务的标识(这是一个问题相关的表示)作为附加输入,作为模型性能的上限。所有的实验都是使用TensorFlow(Abadi等人,2016)进行的,它允许在元学习过程中通过梯度更新进行自动区分。该代码可在线使用1。

5.1.回归

我们从一个简单的回归问题开始,它说明了MAML的基本原理。每项任务都涉及从正弦波的输入到输出的回归,在正弦波的振幅和相位在任务之间是不同的。因此,p(T)是连续的,其中振幅在[0:1;5:0]范围内变化,相位在[0;π]范围内变化,并且输入和输出的维数都为1。在训练和测试过程中,数据点x从–-5:0;5:0均匀采样。损失是预测f(x)和真值之间的均方误差。回归器是一个神经网络模型,有2个隐藏层,大小为40,具有ReLU非线性。在使用MAML进行训练时,我们使用一个K=10的梯度更新示例(步长α=0:01),并使用Adam作为元优化器(Kingma&Ba,2015)底线同样是由亚当训练的。为了评估性能,我们在不同数量的K个例子上对一个元学习模型进行微调,并将性能与两个基线进行比较:(a)对所有任务进行预训练,这需要训练一个网络回归到随机正弦函数,然后在测试时对K个提供的点进行梯度下降微调,使用自动调整的步长,和(b)接收真实振幅和相位作为输入的甲骨文。在附录C中,我们展示了与其他多任务和适应方法的比较。

我们通过微调MAML学习的模型和K=f5;10;20g数据点上的预训练模型来评估性能。在微调过程中,使用相同的K个数据点计算每个梯度步长。定性结果,如图2所示,并在附录B中进一步扩展,表明所学习的模型能够快速适应只有5个数据点,如紫色三角形所示,然而,在所有任务上使用标准监督学习进行预训练的模型,如果没有灾难性的过度拟合,就无法充分适应如此少的数据点。关键的是,当K个数据点都在输入范围的一半时,用MAML训练的模型仍然可以推断出另一半范围内的振幅和相位,这表明用MAML训练的模型f已经学会了模拟正弦波的周期性。此外,我们观察到在定性和定量结果(图3和Appendix B)中,用MAML学习的模型继续以附加的梯度步骤改进,尽管在一个梯度步骤之后被训练以获得最大性能。这一改进表明,MAML优化了参数,使其位于一个易于快速适应的区域,并且对p(T)的损失函数敏感(如第2.2节所述),而不是过度拟合仅在一步后改进的参数θ

图2。对于简单的回归任务很少有镜头适应。左:注意,MAML能够估计曲线中没有数据点的部分,这表明模型已经了解了正弦波的周期结构。右图:对一个模型进行微调,在没有MAML的情况下,对相同的任务分布进行预训练,并调整步长。由于训练前任务的输出往往相互矛盾,该模型无法恢复合适的表示,也无法从少量的测试时间样本中进行外推。

图3。定量正弦回归结果显示了元测试时的学习曲线。注意,在元测试期间,在不过度拟合极小数据集的情况下,MAML继续通过额外的梯度步骤进行改进,从而实现比基线微调方法低得多的损失。

5.2分类

为了比较已有的元学习算法和少镜头学习算法,我们将我们的方法应用于Omniglot(Lake et al.,2011)和minimagenet数据集上的少镜头图像识别。Omniglot数据集由来自50个不同字母表的1623个字符的20个实例组成。每个实例都是由不同的人绘制的。MiniImagenet数据集由Ravi&Larochelle(2017)提出,涉及64个培训班、12个验证班和24个测试班。Omniglot和MiniImagenet图像识别任务是最近使用最普遍的少量镜头学习基准(Vinyals等人,2016;Santoro等人,2016;Ravi&Larochelle,2017)。我们遵循Vinyals等人提出的实验方案。(2016),包括快速学习N向分类,1或5个镜头。N-way分类的问题是:选择N个不可见的类,为模型提供N个类中每个类的K个不同实例,并评估模型在N个类中对新实例进行分类的能力。对于Omniglot,我们随机选择1200个字符进行训练,而不考虑字母表,并使用其余字符进行测试。根据Santoro等人的建议,Omniglot数据集的旋转度增加了90度的倍数。(2016年)。

我们的模型遵循与Vinyals等人使用的嵌入函数相同的架构。(2016),它有4个模块,3×3卷积和64个滤波器,然后是批处理规范化(Ioffe&Szegedy,2015)、ReLU非线性和2×2 max池。Omniglot图像被降采样到28×28,因此最后一个隐藏层的维数为64。在Vinyals等人使用的基线分类器中。(2016),最后一层被送入softmax。对于Omniglot,我们使用跨步卷积而不是max池。对于MiniImagenet,我们每层使用32个过滤器以减少过度拟合,如所做的(Ravi&Larochelle,2017)。为了对记忆增强神经网络(桑托罗等人,2016)也提供公平的比较,并且为了测试MAML的灵活性,我们也为非卷积网络提供结果。为此,我们使用4个隐藏层的网络,每个隐藏层的大小分别为256、128、64、64,包括批处理规范化和ReLU非线性,然后是线性层和softmax。对于所有模型,损失函数是预测类与真类之间的交叉熵误差。附录A.1中包含了额外的超参数细节。

结果见表1。由MAML学习的卷积模型与这项任务的最新结果相比,有很好的性能,远远优于先前的方法。一些现有的方法,如匹配网络,暹罗网络,和记忆模型的设计与少镜头分类铭记,不容易适用于域,如强化学习。此外,与匹配网络和元学习者LSTM相比,使用MAML学习的模型使用更少的总体参数,因为该算法不引入任何超出分类器本身权重的额外参数。与这些先前的方法相比,记忆增强神经网络(Santoro等人,2016)特别是递归元学习模型,代表了一类更广泛适用的方法,如MAML,可用于其他任务,如强化学习(Duan等人,2016b;Wang等人,2016)。然而,如比较所示,在5路泛光分类和minimagenet分类上,MAML在1-shot和5-shot两种情况下都显著优于记忆增强网络和元学习者LSTM。

表1。在突出的Omniglot字符(顶部)和minimagenet测试集(底部)上很少有镜头分类。MAML所获得的结果与最新的卷积和递归模型相当或优于它们。暹罗网、匹配网和内存模块方法都是特定于分类的,不直接适用于回归或RL场景。在任务中,±表示95%的置信区间。请注意,Omniglot的结果可能不具有严格的可比性,因为先前工作中使用的列车/测试分段不可用。基线方法和匹配网络的MiniImagenet评估来自Ravi&Larochelle(2017)。

在MAML中,一个重要的计算开销来自于在通过元目标中的梯度算子反向传播元梯度时使用二阶导数(见方程(1))。在MIN IMANENET中,我们与MAML的一阶近似进行比较,其中省略了这些第二导数。注意,结果方法仍然计算更新后参数值θi0处的元梯度,这提供了有效的元学习。然而,令人惊讶的是,该方法的性能几乎与使用全二阶导数获得的结果相同,这表明MAML中的大多数改进来自于更新后参数值处目标的梯度,而不是通过梯度更新来区分的二阶更新。过去的工作已经观察到,Relu神经网络在局部几乎是线性的(GooFisher等人,2015),这表明,在大多数情况下,二阶导数可能接近于零,部分地解释了一阶近似的良好性能。这种近似消除了在额外的向后传递中计算Hessian向量积的需要,我们发现在网络计算中导致大约33%的加速。

Figure 4. Top: quantitative results from 2D navigation task, Bottom: qualitative comparison between model learned with MAML and with fine-tuning from a pretrained network.

5.3强化学习

为了评估增强学习问题上的MAML,我们基于rllab基准套件中的模拟连续控制环境构建了几组任务(Duan等人,2016a)。我们在下面讨论各个领域。在所有领域中,由MAML训练的模型都是一个具有两个100大小的隐藏层的具有ReLU非线性的神经网络策略。梯度更新使用vanilla policy gradient(REINFORCE)(Williams,1992)计算,我们使用信任区域策略优化(TRPO)作为元优化器(Schulman等人,2015)。为了避免计算三阶导数,我们使用有限差分来计算TRPO的Hessian向量积。对于学习和元学习更新,我们使用Duan等人提出的标准线性特征基线。(2016a),对于批次中的每个采样任务,在每次迭代时分别进行拟合。我们比较了三个基线模型:(a)对所有任务预先训练一个策略,然后进行微调;(b)从随机初始化的权重训练策略;(c)接收任务参数作为输入的oracle策略,对于下面的任务,该策略对应于代理的目标位置、目标方向或目标速度。(a)和(b)的基线模型通过手动调整步长的梯度下降进行微调。学习政策的视频可以在sites.google.com/view/maml上查看。

二维导航。在我们的第一个meta-RL实验中,我们研究了一组任务,其中一个点代理必须在2D内移动到不同的目标位置,在一个单位平方内为每个任务随机选择。观测值是当前的二维位置,动作对应于速度命令,该命令被剪裁为在范围[-0:1;0:1]内。奖励是到目标的负平方距离,当代理在目标的0:01范围内或在H=100的地平线上时,事件终止。使用1个策略梯度更新使用20个轨迹,用MAML训练策略以最大化性能。附录A.2中提供了此问题和以下RL问题的附加超参数设置。在我们的评估中,我们比较了对一个新任务的适应,这个任务最多有4个梯度更新,每个更新有40个样本图4中的结果显示了使用MAML初始化的模型的适应性能、对同一组任务的常规预训练、随机初始化以及接收目标位置作为输入的oracle策略。结果表明,MAML可以学习一个模型,该模型在一次梯度更新中适应得更快,并且随着更新的增加而不断改进。

移动。为了研究MAML如何适应更复杂的深层RL问题,我们还利用MuJoCo模拟器研究了高维运动任务的适应性(Todorov等人,2012)。这些任务需要两个模拟机器人——一个平面猎豹和一个三维四足动物(简称“蚂蚁”)朝特定方向或以特定速度奔跑。在目标速度实验中,奖励是当前代理速度和目标速度之间的负绝对值,猎豹在0:0到2:0之间,蚂蚁在0:0到3:0之间均匀随机选择。在目标方向实验中,奖励是在p(T)中为每项任务随机选择的前进或后退方向的速度大小。horizon是H=200,对于所有问题,每个渐变步骤有20个卷展栏,除了ant forward/backward任务,该任务每个步骤使用40个卷展栏。图5中的结果表明,MAML学习的模型可以快速调整其速度和方向,甚至只需一次梯度更新,并继续改进更多的梯度步骤。结果还表明,在这些具有挑战性的任务中,MAML初始化明显优于随机初始化和预训练。事实上,在某些情况下,预训练比随机初始化更糟,这是在先前的RL工作中观察到的事实(Parisotto等人,2016)。

图5。强化半猎豹和蚂蚁移动任务的学习效果,任务显示在最右侧。与有监督的学习任务不同,每个梯度步骤都需要来自环境的额外样本。结果表明,与传统的预训练和随机初始化相比,多目标学习算法能够更快地适应新的目标速度和方向,仅需两到三个梯度步就能获得良好的性能。我们排除了目标速度和随机基线曲线,因为回报率更差(猎豹低于-200,蚂蚁低于-25)。

6.讨论和今后的工作

提出了一种基于梯度下降学习模型参数的元学习方法。我们的方法有很多好处。它是简单的,不引入任何学习参数的金属学习。它可以与任何适合于梯度训练的模型表示和任何可微目标相结合,包括分类、回归和强化学习。最后,由于我们的方法只产生一个权值初始化,因此可以对任意数量的数据和任意数量的梯度步骤进行自适应,尽管我们仅用一个或五个示例来演示分类的最新结果。我们还表明,我们的方法可以使用策略梯度和非常有限的经验来适应RL代理。

重用过去任务中的知识可能是制作高容量可伸缩模型(如深度神经网络)的一个重要组成部分,该模型能够使用小数据集进行快速训练。我们相信这项工作是朝着一个简单通用的元学习技术迈出的一步,它可以应用于任何问题和任何模型。该领域的进一步研究可以使多任务初始化成为深度学习和强化学习的标准组成部分。

致谢

作者要感谢陈曦和特雷弗·达雷尔进行了有益的讨论,段彦宏和亚历克斯·李提供了技术建议,尼基尔·米什拉、唐浩然和格雷格·卡恩提供了对论文初稿的反馈,以及匿名评论者的评论。这项工作在一定程度上得到了ONR-PECASE奖和NSF-GRFP奖的支持。

A.附加实验细节

在本节中,我们将提供实验设置和超参数的其他详细信息。

A.1. 分类

对于N-way,K-shot分类,每个梯度使用NK示例的批大小计算。对于Omniglot,5路卷积和非卷积MAML模型分别采用1个梯度步长(步长α=0:4)和32个任务的元批处理进行训练。使用3个梯度步骤对网络进行评估,步骤大小α=0:4。对20路卷积MAML模型进行训练,用5个步长α=0:1的梯度步长进行评价。在训练期间,元批大小被设置为16个任务。对于MiniImagenet,两个模型都使用5个α=0:01的梯度步骤进行训练,并在测试时使用10个梯度步骤进行评估。在Ravi&Larochelle(2017)之后,每个类使用15个示例来评估更新后的元梯度。我们分别使用4个和2个任务的元批处理来进行1-shot和5-shot训练。所有模型都在一个NVIDIA Pascal Titan X GPU上训练了60000次迭代。

A.2. 增强学习
在所有强化学习实验中,使用α=0:1的单一梯度步长训练MAML策略。在评估过程中,我们发现在第一个梯度步骤后将学习率减半会产生更好的效果。因此,适应过程中的步长在第一步中设置为α=0:1,在以后的所有步骤中设置为α=0:05。基线方法的步长是为每个域手动调整的。在二维导航中,我们使用20的元批大小;在移动问题中,我们使用40个任务的元批大小。对MAML模型进行了高达500次元迭代的训练,并使用训练期间平均回报率最高的模型进行评估。对于蚂蚁目标速度任务,我们在每个时间步添加一个正奖励奖金,以防止蚂蚁结束该事件。

B.附加正弦结果

在图6中,我们展示了在10次射击学习中训练和在5次射击、10次射击和20次射击中评估的MAML模型的全部定量结果。在图7中,我们展示了随机采样的正弦曲线上MAML和预训练基线的定性性能。

C.其他比较

在本节中,我们将对我们的方法进行更彻底的评估,包括额外的多任务基线和代表Rei(2015)方法的比较。

C.1多任务baseline

正文中的预训练基线训练了一个单一的网络,我们称之为“所有任务的预训练”。为了评估这个模型,就像使用MAML一样,我们使用K个示例在每个测试任务上微调了这个模型。在我们所研究的领域中,不同的任务对于相同的输入涉及不同的输出值。因此,通过对所有任务进行预训练,模型将学习输出特定输入值的平均输出。在某些情况下,该模型可能对实际域了解得很少,而是了解输出空间的范围。

我们尝试了一种多任务方法来提供一个比较点,在这里,我们不是在输出空间中求平均值,而是在参数空间中求平均值。为了获得参数空间的平均值,我们在500个p(T)任务上训练了500个独立的模型。每一个模型都是随机初始化的,并根据分配给它的任务中的大量数据进行训练。然后,我们对模型取平均参数向量,并对5个数据点进行微调,调整步长。由于计算的需要,我们所有的实验都是正弦波的各个回归方程的误差很低:在各自的正弦波上小于0.02。

我们尝试了三种不同的设置。在个体回归者的训练中,我们尝试使用以下方法之一:到目前为止,训练回归者的平均参数向量没有正则化、标准'2权重衰减和'2权重正则化。后两种变体鼓励单个模型找到简洁的解决方案。在使用正则化时,我们将正则化的大小设置为尽可能高,而不会显著地影响性能。在我们的结果中,我们称这种方法为“多任务”。如表2中的结果所示,我们发现参数空间的平均值(多任务)比输出空间的平均值(所有任务的预训练)执行得差。这表明,在分别进行任务训练时,很难找到多个任务的简约解,而且MAML正在学习比平均最优参数向量更复杂的解。

C.2上下文向量自适应

Rei(2015)开发了一种学习可在线调整的上下文向量的方法,并将其应用于递归语言模型。这个上下文向量中的参数是以与MAML模型中的参数相同的方式学习和调整的。为了与使用这种上下文向量解决元学习问题进行比较,我们将一组自由参数z连接到输入x,并且只允许梯度步长修改z,而不是像MAML那样修改模型参数θ。对于图像输入,z与输入图像按通道连接。我们在Omniglot和两个RL域上按照相同的实验协议运行此方法。我们在表3、4和5中报告结果。学习适应性的上下文向量在玩具点质量问题上表现良好,但在更困难的问题上可能是次要的,这可能是由于不太灵活的元优化。

图6。定量正弦波回归结果显示了不同K个测试时间样本的测试时间学习曲线。使用相同的K个例子计算每个梯度步长。注意,在元测试期间,MAML在不过度拟合极小数据集的情况下,继续通过额外的梯度步骤进行改进,并实现了比基线微调方法低得多的损失。

表2。正弦回归域上的附加多任务基线,显示5次均方误差。结果表明,MAML学习的解比平均最优参数向量更复杂。

表3。五向全向分类

表4。二维点质量,平均收益

表5。半猎豹前进/后退,平均回报

图7。正弦回归任务定性结果的随机样本

Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(用于深度网络快速适应的元学习)的更多相关文章

  1. 深度学习课程笔记(十七)Meta-learning (Model Agnostic Meta Learning)

    深度学习课程笔记(十七)Meta-learning (Model Agnostic Meta Learning) 2018-08-09 12:21:33 The video tutorial can ...

  2. 论文笔记:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

    Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks ICML 2017 Paper:https://arxiv.org/ ...

  3. 深度学习材料:从感知机到深度网络A Deep Learning Tutorial: From Perceptrons to Deep Networks

    In recent years, there’s been a resurgence in the field of Artificial Intelligence. It’s spread beyo ...

  4. (转)Paper list of Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning

    Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning 2018-08-03 19:16:56 本文转自:http ...

  5. 什么是 Meta Learning / Learning to Learn ?

    Learning to Learn Chelsea Finn    Jul 18, 2017 A key aspect of intelligence is versatility – the cap ...

  6. The Rise of Meta Learning

    The Rise of Meta Learning 2019-10-18 06:48:37 This blog is from: https://towardsdatascience.com/the- ...

  7. 论文笔记:Visual Question Answering as a Meta Learning Task

    Visual Question Answering as a Meta Learning Task ECCV 2018 2018-09-13 19:58:08 Paper: http://openac ...

  8. 【元学习】Meta Learning 介绍

    目录 元学习(Meta-learning) 元学习被用在了哪些地方? Few-Shot Learning(小样本学习) 最近的元学习方法如何工作 Model-Agnostic Meta-Learnin ...

  9. 【MetaPruning】2019-ICCV-MetaPruning Meta Learning for Automatic Neural Network Channel Pruning-论文阅读

    MetaPruning 2019-ICCV-MetaPruning Meta Learning for Automatic Neural Network Channel Pruning Zechun ...

随机推荐

  1. git部分命令笔记

    目录 配置user信息 建Git仓库 清空暂存区 git变更文件名 查看暂存区状态 查看历史 查看本地分支 查看所有分支(包含远程) 创建分支 基于远程分支创建本地新分支 查看图形化分支日志 图形化界 ...

  2. 手把手教你vue配置请求本地json数据

    本篇文章主要介绍了vue配置请求本地json数据的方法,分享给大家,具体如下:在build文件夹下找到webpack.dev.conf.js文件,在const portfinder = require ...

  3. 下载MySQL的rpm包安装MySQL

    cd /usr/local/src wget https://cdn.mysql.com//Downloads/MySQL-5.7/mysql-community-server-5.7.27-1.el ...

  4. jse中将数据反转

    public class test { public static void main(String args[]){ String arr[]={"1","2" ...

  5. 关于hstack和Svstack

    关于hstack和Svstack import numpy as np>>> a = np.array((1,2,3))>>> aarray([1, 2, 3])& ...

  6. vue项目1-pizza点餐系统2-配置路由跳转

    功能目标:点击导航栏中的菜单.主页.路由跳转到不同的组件,点击谁就在在导航栏下展示谁. 1.在router文件夹中(在用脚手架cli搭建项目时,有个couter的选yes)的index.js中,导入如 ...

  7. aria2的安装与配置

    aria2安装 安装 epel 源: yum install epel-release 然后直接安装: yum install aria2 -y 配置 Aria2 创建目录与配置文件 这一步需要切换到 ...

  8. C++ ->error LNK1123

    终极解决方案:VS2010在经历一些更新后,建立Win32 Console Project时会出“error LNK1123” 错误,解决方案为将 项目|项目属性|配置属性|清单工具|输入和输出|嵌入 ...

  9. JS基础知识二

    JS控制语句 switch 语句用于基于不同的条件来执行不同的动作 <script> function myFunction(){ var x; var d=new Date().getD ...

  10. RBAC | YAML |

    YAML配置文件: 1.凡是可以在application.properties配置的文件,都可以在application.yaml文件中配置 2.properties的优先级大于yaml的优先级 后端 ...