导言

传统的神经网络都是基于固定的数据集进行训练学习的,一旦有新的,不同分布的数据进来,一般而言需要重新训练整个网络,这样费时费力,而且在实际应用场景中也不适用,所以增量学习应运而生。

增量学习主要旨在解决灾难性遗忘(Catastrophic-forgetting) 问题,本文将要介绍的《iCaRL: Incremental Classifier and Representation Learning》一文中对增量学习算法提出了如下三个要求:

a) 当新的类别在不同时间出现,它都是可训练的

b) 任何时间都在已经学习过的所有类别中有很好的分类效果

c) 计算能力与内存应该随着类别数的增加固定或者缓慢增长

有条件的可以去油管听听原作者对这篇论文的讲座:Christoph Lampert: iCaRL- incremental Classifier and Representation Learning

简要概括

本文提出的方法只需使用一部分旧数据而非全部旧数据就能同时训练得到分类器和数据特征从而实现增量学习。

大致流程如下:

1.使用特征提取器\(φ(·)\)对新旧数据(旧数据只取一部分)提取特征向量,并计算出各自的平均特征向量
2.通过最近均值分类算法(Nearest-Mean-of-Examplars) 计算出新旧数据的预测值
3.在上面得到的预测值代入如下loss函数进行优化,最终得到模型。

本文的重点在上面三个步骤中用黑体标出,下面对这三个进行具体介绍

1.平均特征向量

这个其实很好理解,就是把某一类的图像的特征向量都计算出来,然后求均值,注意本文对于旧数据,只需要计算一部分的数据的特征向量。

什么意思呢?

假设我们现在已经训练了\(s-1\)个类别的数据了,记为\(X^1,...,X^{s-1}\),因为通常内存资源有限,所以假设从每个旧数据类中选出一定数量的数据组成examplar sets,记为\(P^1,...,P^{s-1}\)。

然后现在又得到了\(t-s\)个新数据,记为\(X^s,...,X^t\)。同理我们也需要提取出一部分数据,记为\(P^s,...,P^t\)

如何选取数据可参见文末算法示意图

有了新旧数据后,我们可以先将它们合并,记为\(P=\{P^1,...,P^t\}\),然后就可以使用特征提取器\(φ(·)\)计算每个类别的平均特征向量了。

2.最近均值分类算法(Nearest-Mean-of-Examplars classification)

算法第七行在文首给出的讲座中,使用的是\(\|φ(x)-μ_y\|^2\)。 emm... anyway,这不是重点,pass。

3.优化loss函数

机器学习归根到底其实就是优化,那么loss函数如何设定才能解决灾难性遗忘的问题呢?

本文的损失函数定义如下,由新数据分类loss和旧数据蒸馏loss组成。下面公式中的\(g_y(x_i)\)表示分类器,即\(g_y(x)=\frac{1}{1+e^{-{w_y^Tφ(x)}}}\)。

其实该想法其实是基于LWF这篇论文,LWF的loss函数如下:

结果

本文最终结果如下图示,将iCaRL,fixed representation(feature extraction), fine-tuning和LWF进行了比较,可以看到iCaRL表现最好。

讨论

需要说明的是iCaRL和LWF最大的不同点有如下:

  • iCaRL在训练新数据时仍然需要使用到旧数据,而LWF完全不用。所以这也就是为什么LWF表现没有iCaRL好的原因,因为随着新数据的不断加入,LWF逐渐忘记了之前的数据特征。
  • iCaRL提取特征的部分是固定的,只需要修改最后分类器的权重矩阵。而LWF是训练整个网络(下图给出了LWF和fine-tuning以及feature extraction的示意图)。

选取数据算法示意图

MARSGGBO♥原创







2019-1-25

论文笔记系列-iCaRL: Incremental Classifier and Representation Learning的更多相关文章

  1. 论文笔记系列-Neural Architecture Search With Reinforcement Learning

    摘要 神经网络在多个领域都取得了不错的成绩,但是神经网络的合理设计却是比较困难的.在本篇论文中,作者使用 递归网络去省城神经网络的模型描述,并且使用 增强学习训练RNN,以使得生成得到的模型在验证集上 ...

  2. 论文笔记系列-Neural Network Search :A Survey

    论文笔记系列-Neural Network Search :A Survey 论文 笔记 NAS automl survey review reinforcement learning Bayesia ...

  3. 论文笔记系列-Auto-DeepLab:Hierarchical Neural Architecture Search for Semantic Image Segmentation

    Pytorch实现代码:https://github.com/MenghaoGuo/AutoDeeplab 创新点 cell-level and network-level search 以往的NAS ...

  4. 【论文笔记系列】AutoML:A Survey of State-of-the-art (下)

    [论文笔记系列]AutoML:A Survey of State-of-the-art (上) 上一篇文章介绍了Data preparation,Feature Engineering,Model S ...

  5. 论文解读( N2N)《Node Representation Learning in Graph via Node-to-Neighbourhood Mutual Information Maximization》

    论文信息 论文标题:Node Representation Learning in Graph via Node-to-Neighbourhood Mutual Information Maximiz ...

  6. 论文解读(GMI)《Graph Representation Learning via Graphical Mutual Information Maximization》2

    Paper Information 论文作者:Zhen Peng.Wenbing Huang.Minnan Luo.Q. Zheng.Yu Rong.Tingyang Xu.Junzhou Huang ...

  7. 论文解读(GMI)《Graph Representation Learning via Graphical Mutual Information Maximization》

    Paper Information 论文作者:Zhen Peng.Wenbing Huang.Minnan Luo.Q. Zheng.Yu Rong.Tingyang Xu.Junzhou Huang ...

  8. 论文解读(MVGRL)Contrastive Multi-View Representation Learning on Graphs

    Paper Information 论文标题:Contrastive Multi-View Representation Learning on Graphs论文作者:Kaveh Hassani .A ...

  9. 论文解读(GRCCA)《 Graph Representation Learning via Contrasting Cluster Assignments》

    论文信息 论文标题:Graph Representation Learning via Contrasting Cluster Assignments论文作者:Chun-Yang Zhang, Hon ...

随机推荐

  1. pstack跟踪进程栈

    一:简介 这个命令可以显示每个进程的栈跟踪.pstack命令必须由相应进程的宿主或root运行.可以使用pstack来确定进程挂起的位置.此命令允许使用唯一选项就是进程的PID 这个命令在排查进程问题 ...

  2. socket编程 ------ UDP服务器

    void vLANcommunication( void *pvParameters ) { int32 listenfd; do{ listenfd = socket(AF_INET, SOCK_D ...

  3. Redis分布式锁----悲观锁实现,以秒杀系统为例

    摘要:本文要实现的是一种使用redis来实现分布式锁. 1.分布式锁 分布式锁在是一种用来安全访问分式式机器上变量的安全方案,一般用在全局id生成,秒杀系统,全局变量共享.分布式事务等.一般会有两种实 ...

  4. python机器学习-sklearn挖掘乳腺癌细胞(一)

    python机器学习-sklearn挖掘乳腺癌细胞( 博主亲自录制) 网易云观看地址 https://study.163.com/course/introduction.htm?courseId=10 ...

  5. Java Web之下载文件

    下载的文件,不能随便的被访问,放在外面的文件夹肯定不行,url一敲就能访问了,所以我们要放在WEB-INF文件夹里面,WEB-INF文件夹只有Servlet才能访问,我们新建一个文件夹,叫downlo ...

  6. testlink for windows 安装

    testlink的使用说明可到官网查看:http://www.testlink.org.cn/509.html 一.安装xampp 到xampp官网中下载安装文件,按步骤安装即可. 二.Testlin ...

  7. 微信、支付宝支付SDK

    1.首先是下载SDK,其对应的SDK在mvn上下载不了,需要手动配置到仓库 支付宝SDK下载地址 https://docs.open.alipay.com/54/103419 微信SDK官方下载地址  ...

  8. bzoj千题计划320:bzoj4939: [Ynoi2016]掉进兔子洞(莫队 + bitset)

    https://www.lydsy.com/JudgeOnline/problem.php?id=4939 ans= r1-l1+1 + r2-l2+1 +r3-l3+1 - ∑ min(cnt1[i ...

  9. c/c++gdb下和发布版本下输出地址不同

    相差4字节 相差8个字节 原因: 这4个字节是优化掉了,64位操作系统,函数传参通过寄存器,减少了栈的使用 debug模式下,abc的地址都存下来了.

  10. navicat and connection is being used

    1.在已经保存的连接上上编辑,测试连接成功,但是点击连接就会一直提示 connection is being used 2.需要新建一个连接,才能使用,不能再已保存的上面修改