Neyshabur B., Sedghi H., Zhang C. What is being transferred in transfer learning?

arXiv preprint arXiv 2008.11687, 2020.

迁移学习到底迁移了什么?

主要内容

  • T: 普通训练的模型
  • P: 预训练的模型
  • RI: 随机初始化的模型
  • RI-T: 随机初始化再经过普通训练的模型
  • P-T: 在预训练的基础上再fine-tuning的模型

本文的预训练都是在ImageNet上, 然后在CheXpert和DomainNet(分为real, clipart, quickdraw)上测试.

feature reuse

大家认为迁移学习有用的一个直觉就是迁移学习通过特征的复用来样本少的数据提供一个较好的特征先验.



通过上面的图可以看到, P-T总是能够表现优于RI-T, 这能够支撑我们的观点. 但是, 为什么数据差别特别大的时候, 预训练还是有用呢(此时feature reuse的作用应该不是很明显)? 作者将图片按照不同的block size打乱(就像最开始的那些乱七八糟的图片). 这个时候, 模型应该只能抓住浅层的特征, 抽象的特征是没法被很好提取的, 结果如下图所示.

  • 当打乱的程度加剧(block size变小), 任务越发困难;
  • 相对正确率差距\((A_{P-T}-A_{RI-T})/A_{P-T} \%\)随着block size减小而减小(clipart, real), 这说明feature reuse很有效果, quickdraw 相反是由于其数据集和预训练的数据集相差过大, 但是即便如此, 在quickdraw上预训练还是有效的, 说明存在除了feature reuse外的因素;
  • P-T的训练速度(右图)一直很稳定, 而RI-T的训练速度则在block size下降的时候有一个急剧的下降, 这说明feature reuse并不是影响P-T训练速度的主要因素.

mistakes and feature similarity

这部分通过探究不同模型有哪些common和uncommon的mistakes来揭示预训练的作用.

P-T在简单样本上的成功率很高, 而在比较模糊难以判断的样本上比较难(而此时RI-T往往比较好), 这说明P-T有着很强的先验.

通过 centered kernel alignment (CKA) 来衡量特征之间的相似度:



可以发现, 基于预训练的模型之间的特征相似度很高, 而RI-T与别的模型相似度很低, 即便是两个相同初始化的RI-T. 说明预训练模型之间往往是在重复利用相同的特征.

下表为不同模型的参数的\(\ell_2\)距离, 同样能够反映上面一点.

loss landscape

用\(\Theta, \tilde{\Theta}\)表示两个checkpoint的参数, 通过线性插值

\[\{\Theta_{\lambda} = (1- \lambda) \Theta + \lambda \tilde{\Theta}: \lambda \in [0, 1]\},
\]

考量模型在\(\Theta_{\lambda}\)下的表现.



上图, 左为DomainNET real, 右为quickdraw, 可见预训练模型之间的loss landscape是很光滑的, 不同于RI-T.

module criticality

如果我们将训练好后的模型的某一层参数替换为其初始参数, 然后观察替换前后的正确率就能一定程度上判断这个层在整个网络中的重要性, module criticality就是一个这样的类似的指标.

下图反映了不同模型的不同层的criticality.

下图反映了RI-T的训练后的参数\(\theta\)其实加了扰动反而性能更好? 而P-T的就相当稳定.

pre-trained checkpoint

我们选pre-trained模型的时候, 往往是通过正确率指标来判断的, 但是事实上, 这个判断并不十分准确, 事实上我们可以早一步地选取checkpoint (直观上理解, 大概是只要参数进入了那个光滑的盆地就行了).

What is being transferred in transfer learning?的更多相关文章

  1. (转)Understanding, generalisation, and transfer learning in deep neural networks

    Understanding, generalisation, and transfer learning in deep neural networks FEBRUARY 27, 2017   Thi ...

  2. 迁移学习( Transfer Learning )

    在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型:然后利用这个学习到的模型来对测试文档进行分类与预测.然而,我们看到机器学习算法在当前的Web挖掘研究中存在着一个关 ...

  3. 【迁移学习】2010-A Survey on Transfer Learning

    资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...

  4. 迁移学习(Transfer Learning)(转载)

    原文地址:http://blog.csdn.net/miscclp/article/details/6339456 在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型 ...

  5. Transfer learning across two sentiment classes using deep learning

    用深度学习的跨情感分类的迁移学习 情感分析主要用于预测人们在自然语言中表达的思想和情感. 摘要部分:two types of sentiment:sentiment polarity and poli ...

  6. 读论文系列:Deep transfer learning person re-identification

    读论文系列:Deep transfer learning person re-identification arxiv 2016 by Mengyue Geng, Yaowei Wang, Tao X ...

  7. 迁移学习-Transfer Learning

    迁移学习两种类型: ConvNet as fixed feature extractor:利用在大数据集(如ImageNet)上预训练过的ConvNet(如AlexNet,VGGNet),移除最后几层 ...

  8. CVPR2018: Unsupervised Cross-dataset Person Re-identification by Transfer Learning of Spatio-temporal Patterns

    论文可以在arxiv下载,老板一作,本人二作,也是我们实验室第一篇CCF A类论文,这个方法我们称为TFusion. 代码:https://github.com/ahangchen/TFusion 解 ...

  9. pytorch例子学习——TRANSFER LEARNING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 以下是两种主要的迁移学习场景 微调convnet : ...

随机推荐

  1. 备忘录:关于.net程序连接Oracle数据库

    目录 关于使用MSSM访问Oracle数据库 关于. net 程序中连接Oracle数据库 志铭-2021年12月7日 21:22:15 关于使用MSSM访问Oracle数据库 安装访问接口组件:Or ...

  2. add more

    # -*- coding: utf-8 -*- print('123', 123) print(type('123'), type(123)) # string, integer /ˈintidʒə/ ...

  3. java中的collection小结

    Collection 来源于Java.util包,是非常实用常用的数据结构!!!!!字面意思就是容器.具体的继承实现关系如下图,先整体有个印象,再依次介绍各个部分的方法,注意事项,以及应用场景.   ...

  4. 在Eclipse中运行OSGI工程出错的解决方案

    今天学习OSGI的过程中按照书上所述搭建好第一个helloworld插件工程,运行的过程中出现下面所示的错误: !SESSION 2014-06-09 21:04:49.038 ----------- ...

  5. Tomcat(2):配置Tomcat

    1,打开IDEA创建一个项目 2,配置Tomcat服务器 3,运行 5,成功 t t

  6. Java实现邮件收发

    一. 准备工作 1. 传输协议 SMTP协议-->发送邮件: 我们通常把处理用户smtp请求(邮件发送请求)的服务器称之为SMTP服务器(邮件发送服务器) POP3协议-->接收邮件: 我 ...

  7. springboot整合jetty

    1.jetty介绍 通常我们进行Java Web项目开发,必须要选择一种服务器来部署并运行Java应用程序,Tomcat和Jetty作为目前全球范围内最著名的两款开源servlet容器,该怎么选呢. ...

  8. matplotlib画3d图

    import numpy as npimport matplotlib.pyplot as pltfrom mpl_toolkits.mplot3d import Axes3D fig = plt.f ...

  9. 使用匿名内部类和lamda的方式创建线程

    1.匿名内部类的方式 1 /** 2 *匿名内部类的方式启动线程 3 */ 4 public class T2 { 5 public static void main(String[] args) { ...

  10. Java SPI机制,你了解过吗?

    Life moves pretty fast,if you don't stop and look around once in a while,you will miss it 为什么需要SPI? ...