CLIP改进工作串讲(上)学习笔记
看了跟李沐学AI系列朱毅老师讲的CLIP改进工作串讲,这里记录一下。
1.分割
分割的任务其实跟分类很像,其实就是把图片上的分类变成像素级别上的分类,但是往往图片上能用的技术都能用到像素级别上来。所以分割的论文很多。
1.1.LSeg(language-driven semantic segmentation)
跟CLIP非常像,只不过图片编码器得到的是dense feature,其实去掉上半部分的文本操作,下半部分跟正常的有监督语义分割没有什么区别,都是图片进入分割模型,得到特征图,维度H×W×C,经过upscaling得到输出,然后和ground-truth做一个交叉熵损失,细致来说,图像的编码器是DPT的结构,也就是前面一个ViT,后面一个decoder,decoder的作用就是把瓶颈特征Upscale上去,这是视觉分支,文本分支将n个标签通过文本编码器,得到n个文本特征,维度为N×C,这里n是可以随意改变的,将视觉特征和文本特征相乘就能得到一个H×W×N的一个tensor,这也就跟传统的分割模型没什么区别了。这篇论文虽然说是zero shot,但是其实是有监督的训练,这篇论文的意义就是把文本的分支加到分割的pipeline中了,将文本特征和视觉特征结合起来,使得模型可以学习到一些language aware的视觉特征,从而使得可以通过文本的prompt去任意得到想要的分割效果。
这里的文本编码器与CLIP的一模一样,甚至在训练的时候也没有动,从头到尾都是冻住的,因为实验用的数据量较小,不过视觉编码器换了效果更好,原因作者没有给出。论文中还加了一个Spatial Reguarizationi Blocks,但是感觉像是一个trick。
1.2.GroupViT(semantic segmentation emerges from text supervision)
LSeg虽然使用了CLIP,但是其实依旧是依靠语义mask作为监督信号,而不是文本,而语义mask是很贵的,这篇论文在使用文本作为监督信号方向是一个比较有贡献的工作,为什么要叫group ViT呢?其实在视觉这边,很早之前做无监督分割的时候,经常用的一类方法就是grouping,类似于如果有一些聚类的中心点,然后从这个点开始发散,把附近周围相似的点扩充成一个group,这个group就相当于语义mask,是一种自上而下的方式,论文作者认为能把这个方法用到当前这个框架中,他们提出了一个计算单元,也就是右边这一部分,叫做grouping block,还有一些可以学习的group tokens,该方法的目的就是想让这个模型在初始学习的时候就能慢慢地把相邻相近的像素点group起来变成一个又一个的语义mask,从图中可以看出,在模型刚开始的浅层,分割的效果还不是很好,但是经过学习,到深层的时候效果已经十分不错了。
对于图像编码器来说,其实就是vision transformer,一共有12层,也就是有12个transformer layers,图像编码器的输入其实有两个部分,一个是来自原始图像的batch embedding,也就是线性投影层的输出,另外一个就是这篇论文提出来的可学习的group tokens,也就是更右边一点的彩色矩形,这里的group token可以理解为之前的cls token,就是说用这个token去代表整个图片,但是这里之所以是64而不是1,是因为之前为1的时候是代表整个图片有一个特征,但是现在是做分割,所以是每个类别或每个小块都有一个特征,不过两个token的学习过程都是一样的,都是通过transformer layer里的自注意力去学习这些图像的patch属于哪个token,经过6个transformer layer之后,使用一个grouping block进行cluster,将图像patch embedding分配到到64个group token,相当于做了一次聚类的分配,由于有64个聚类中心,所以最后剩下64个token,grouping block另一个好处是变相地把序列长度降低了,计算复杂度和计算时间也就相应减少了。
grouping block见右图,首先用类似于自注意力的方式先算一个相似度矩阵,通过这个矩阵去帮助image token做一个聚类中心的分配,从而实现降维,不过分配聚类中心这个过程是不可导的,所以用了一个trick,也就是gumbel softmax,从而把这个过程变成可导的。
由于一般的数据集或图片里的种类也不会太多,所以作者希望能把64个聚类中心变得更少,所以加了新的8个group token,通过transformer的学习将64个语义token再次映射到8个聚类中心上,这里用的是3个transformer layer,然后再通过grouping block,最终得到一个8×384的token,也就是图像被分成8大块,每一个块有一个特征。
训练过程与CLIP类似,通过图像文本对去算一个对比学习的loss,从而训练整个网络,但是这里有一个问题,CLIP中一个图片是一个特征,文本也是一个特征,但是现在文本是一个特征,但是图像有8大块的特征,所以作者通过一个平均池化融合8大块特征,再通过一层MLP得到整个图像的特征,接下来就跟CLIP完全一致。
在推理的时候,文本编码器每一类生成一个特征,图片生成8个group embedding,进行对比,由于只有8个group embedding,所以图片只能检测到8类。
论文汇总也提到了模型的两个局限性,一个是现在group ViT的结构还是偏向于一个图像的编码器,没有很好的利用dense prediction的特性,另外一个就是背景类,作者是通过一个阈值来判断是前景还是背景,当数据集类别数很多或者背景干扰较大的时候,大部分的分类分数可能不是很高,这时候阈值的设置就会对结果产生影响,作者通过实验发现,其实模型的分割已经做得很好了,但是分类错误很多,这是由于CLIP的训练方式只能学习到物体语义信息非常明确的信息,对于背景类的学习能力较差。
1.3.总结
Lseg直接使用了CLIP的预训练模型,而且使用了CLIP的大概框架,从而吧图像和文本的特征融合在一起,能够去做language-driven的分割,但是还是有监督的学习,group ViT没有使用CLIP的预训练参数,而是自己从头训练了一个分割模型,但是用了CLIP的目标函数,后面可以发现,CLIP出来以后,大家一般刚开始就是先用一下CLIP的预训练参数,做一些简单的改动,然后再把CLIP的特性与下游任务的特性结合起来,要么是利用CLIP的目标函数,要么是利用一些其他的特性。
2.目标检测
2.1ViLD(open-vocabulary object detection via vision and language knowledge distillation)
作者希望在不对Base categories数据集进行额外标注的情况下,实现对novel categories的分类
a为有监督baseline的方法,b,c,d是ViLD的网络结构。
其实这个baseline就是一个mask-RCNN,是一个两阶段的分类器,第一个阶段会出n个proposal,第二个阶段根据这N个proposal经过detection head得到一些region embedding,再通过一个分类头,得到抽取到的bounding box的类别,这样就完成了目标检测。
目标检测从目标函数来看可以分为两块,一个是怎么定位,一个是怎么分类,定位就是bounding box画得准不准,分类就是bounding box里面的物体判断得准不准,这篇论文有点把这两块解耦开来的意思,这里所有的框架图都是从第二个阶段开始,输入都是N个proposal,第一阶段没有画。
首先看一下ViLD-text,想做zero-shot的目标检测,肯定得跟文本联合起来,那怎么把文本加进来呢?最简单的方式就是像CLIP一样,先用一个图像的backbone去抽一些图像的特征,再去找一个文本的网络抽一些文本的特征,最后把这两个特征做个点乘,算一下他们的相似度就可以了,这里也一样,作者采取的就是这种最简单的方式,N个proposal进入检测头,经过一些操作之后得到N个region embedding,也就跟之前的基线网络的N个region embedding差不多,接下来就是算一些文本的embedding,文本的embedding其实就是把物体的这些类别拿过来,然后给一些prompt,然后生成一个句子,然后把这个句子扔给任何一个文本编码器就可以了,这里要注意的是文本来自于物体的类别,也就是做得还是有监督的学习,而且类别还是base category,就是数据集里的基础类,这里baseline和ViLD-text都是在同样的数据集上做有监督的训练,而且是在基础类上训练,所以在这个阶段,ViLD-text其实只是把文本特征和图像特征联系到了一起,但是zero-shot的性能还有待加强。图中text embedding标成了蓝色,也就是说这里的模型参数一直都是锁住的,并没有参与训练,跟Lseg一样,文本端是锁住的。
一旦有了图像特征和文本特征,就可以直接做一个点乘,相似度就可以当成最后分类的那个logits,就可以去做交叉熵损失,进行模型的训练了。
这里需要注意的是back ground,因为现在做的是有监督的训练,用的都是基础类,那不在这些基础类中的所有别的类别就只能全部塞给这个背景类了,所以这个背景类的学习非常的关键,需要专门去学习一个背景的embedding,至于具体的计算,背景的embedding和text embedding一样,都是去跟图像embedding做点乘就可以了。
这里光是ViLD-text做zero-shot的性能还不是很好,毕竟这里只是在基础类上进行训练的,那如何拓展到新的类别呢?或者说如何把CLIP引入到框架里来呢?接下来作者就提出了ViLD-image,其实这里想法也很简单,就是说CLIP预训练的图像编码器非常好,而且跟文本的关联也做的很好,所以作者就希望模型输出的图像embedding与CLIP输出的图像embedding尽可能一致,而做到这一点的最好方式就是知识蒸馏。
具体来说就是当有一些抽好的proposal,也就是得到的那些bounding box,就可以把它抠出来做一些resize的操作,然后就可以把它扔给CLIP预训练好的图像编码器,然后就可以得到图像的特征了,这里预训练好的图像编码器也是锁住的,就可以保证抽出来的特征跟CLIP的一样好,然后把这一个分支,也就是ViLD-image右边这个分支作为teacher网络,student网络就是之前用的目标检测的框架,就是先有一些proposal,过检测头,然后抽一些特征,作者希望这里的特征跟CLIP抽出来的特征尽可能地接近,这里直接用一个L1 loss做一个蒸馏就可以了。这里值得一提的是,现在监督信号不再是人工标注了,而是CLIP带来的图像编码,所以就不受基础类的限制了,所以抽出来的proposal既可以有基础类里来的proposal,也可以有新类里来的proposal。
不过这里有一个小弊端,就是之前这个ViLD-text做的时候,都是N个proposal,最后得到的也是N个embedding,而到了ViLD-image就变成了M pre-computed proposals,而且最后都是M个embedding了,这主要是为了让训练变得更加地快速,理论上是可以都用N个proposal的,但是事实上,如果每次都在模型训练的时候再去抽CLIP的特征就太慢了,因为这里想要一个比较大的CLIP模型,如果做一次前向过程是非常贵的,如果是N个proposal则每个iteration都需要前向N次,计算代价太高。所以作者这里去了一个折中,就是提前先把每张图片利用训练好的rpn先去预抽取M个proposal,而且embedding也抽好了,训练的时候只需要把embedding load进来就行了。不过这就跟N个proposal不一样了,N个是可以随时改变的。
最后的框架d就是两个框架的合体,左边是目标检测的分支,右边是CLIP的图像embedding分支,而且右边只有在训练的时候才用,为了计算上的简单,作者把N个propo和M个pre-computed proposal全都一起给目标头,然后得到N+M个embedding,然后在劈开,N个embedding去算交叉熵损失,M个pre-computed proposal去算蒸馏的L1 loss。
虽然实验中模型比基线分数高了几个点,但是其实是利用了数据集的特性,论文中主要是看了LVIS数据集的尾部部分,这个数据集是一个尾部非常长的长尾数据集,后面的很多类标注非常少,训练其实也效果不大,可能还不如不训练,看数据集的常见部分其实和baseline还是有差距。
2.2glip(grounded language-image pre-training)
这篇论文提出的动机还是想去利用更多的数据,作者发现在视觉语言的下游任务中,还有一类任务叫做视觉grounding,就是给一句话,把这句话中的物体在图片中定位出来,这其实跟目标检测差不多,只不过目标检测是在图片中把bounding box找出来,而视觉grounding是根据文本在图片中找,所以作者就想把这两个任务结合起来,这样就能用大量的图像文本对去训练模型。
最后的zero-shot性能还是很强的,推理过程具体来说,就是给一些物体的标签,把这些标签变成一句话,然后把这句话扔给GLIP模型,然后就把新的类别都检测出来了,或者像视觉grounding一样,给一句话,把这句话中的物体检测出来。
那么GLIP是怎么把这两个任务结合起来的呢?作者先做了一个背景介绍,目标检测的loss一般是一个分类的loss+定位的loss,而这两个任务的定位部分其实都差不多,主要是根据模型是什么来选择怎么生成定位框,区别主要在怎么算分类的loss,因为对于目标检测来说,它的标签是一个或者两个单词,是one-hot标签,但是对于视觉grounding来说标签是一个句子。
然后作者简单介绍了一些两个任务的分类loss是怎么算的。
对于目标检测来说,给定一个图片,输入图像的backbone,得到region embedding,是N×d的,就是有N个bounding box,每个bounding box的维度是d,然后输入一个分类头,得到每一个bounding box中的物体属于什么类,这个分类头就是一个矩阵,维度为c×d,将region embedding和分类头相乘就能得到分类的logits,然后用nms把这些bounding box筛选一下,再去跟ground truth算交叉熵损失就能得到最终的分类loss。
而对于视觉定位来说,其实是算了一个匹配的分数,就是想看看图像中的区域和句子中的单词是怎么匹配上的,图像这边还是一样,有一个图像backbone得到一些region特征,但是接下来就不是一个分类头了,而是像ViLD一样换成了一个文本编码器,给定一个句子,通过文本编码器就能得到文本的embedding,然后把文本的embedding和图像的embedding算一下相似度,就能得到最终的匹配分数,其实就跟ViLD-text分支是一模一样的。
作者发现其实这两种方式差不多,只需要做点小小的改动就能把这两个任务联合起来了,这个改动就是判断一下,什么时候算是一个positive match,什么时候算是一个negative match。然后作者把目标检测换成了统一之后的框架重新验证了一下在COCO上的分数,发现是完全匹配的,也就是说GLIP是完全可以迁移到任何一个目标检测数据集上的。
然后作者把目标检测和视觉grounding的数据集都合并到一起,视觉grounding数据集的bounding box是通过伪标签算的,最后得到一个非常大的预训练数据集。
模型总体框架如下,与CLIP差不多,给定一张图片和一些文本,图片输入图像编码器得到一些region embedding,文本输入文本编码器,得到一些文本的embedding,直接看最后的loss的话,由于这里其实还是有监督的学习,所以O和P点乘之后就可以去跟ground truth算一个对齐损失,这样就完成了图像和文本特征之间的融合,再看中间的fusion部分,理论上经过图像和文本编码器之后是可以直接算相似度矩阵的,但是这样图像文本的joint embedding space还没有学得很好,所以在中间多加一些层数,这里其实就是用cross attention把图像和文本特征交互了一下,这里的fusion其实也可以加到group vit中。
CLIP改进工作串讲(上)学习笔记的更多相关文章
- 6-C++远征之封装篇[上]-学习笔记
C++远征之封装篇(上) 课程简介 类(抽象概念),对象(真实具体) 配角: 数据成员和成员函数(构成了精彩而完整的类) 构造函数 & 析构函数(描述了对象的生生死死) 对象复制和对象赋值 ( ...
- 《java核心技术36讲》学习笔记-------杨晓峰(极客时间)
非常荣幸作为晓峰哥的同事,之前就看过这篇文章,重写读一遍,再学习学习. 一.开篇词 初级.中级:java和计算机科学基础.开源框架的使用:高级.专家:java io/nio.并发.虚拟机.底层源码.分 ...
- 《MySQL实战45讲》学习笔记3——InnoDB为什么采用B+树结构实现索引
索引的作用是提高查询效率,其实现方式有很多种,常见的索引模型有哈希表.有序列表.搜索树等. 哈希表 一种以key-value键值对的方式存储数据的结构,通过指定的key可以找到对应的value. 哈希 ...
- 《MySQL实战45讲》学习笔记2——MySQL的日志系统
一.日志类型 逻辑日志:存储了逻辑SQL修改语句 物理日志:存储了数据被修改的值 二.binlog 1.定义 binlog 是 MySQL 的逻辑日志,也叫二进制日志.归档日志,由 MySQL Ser ...
- 《MySQL实战45讲》学习笔记1——MySQL的基础架构
在<极客时间>订阅了<MySQL实战45讲>专栏,总觉得看完和没看一样
- java核心技术卷上学习笔记
9月5日 学习章节:第二章 Java程序设计环境 学习包括Java的安装.命令行工具.IDE.图形化开发环境等. 9月6日 学习章节:第三章 Java的基本程序设计结构 学习包括注释.数据类型.变量. ...
- 《MySQL实战45讲》学习笔记4——MySQL中InnoDB的索引
索引是在存储引擎层实现的,且在 MySQL 不同存储引擎中的实现也不同,本篇文章介绍的是 MySQL 的 InnoDB 的索引. 下文将以这张表为例开展. # 创建一个主键为 id 的表,表中有字段 ...
- 深挖计算机基础:MySQL实战45讲学习笔记
参考极客时间专栏<MySQL实战45讲>学习笔记 一.基础篇(8讲) MySQL实战45讲学习笔记:第一讲 MySQL实战45讲学习笔记:第二讲 MySQL实战45讲学习笔记:第三讲 My ...
- SpringBoot学习笔记 - 构建、简化原理、快速启动、配置文件与多环境配置、技术整合案例
[前置内容]Spring 学习笔记全系列传送门: Spring学习笔记 - 第一章 - IoC(控制反转).IoC容器.Bean的实例化与生命周期.DI(依赖注入) Spring学习笔记 - 第二章 ...
- HTML&CSS基础学习笔记1.16-单元格间距和表格主体
上一篇讲html学习笔记,讲过了合并单元格,那么今天就来介绍下如何控制单元格的间距,以及表格主体的相关知识. 单元格间距 在上个知识点的显示结果中你可能发现了,单元格与单元格之间有一小段空白.这是由& ...
随机推荐
- java的饿汉和懒汉设计模式
本文主要讲述java的饿汉和懒汉设计模式 饿汉和懒汉设计模式的目的:使得该类的对象,只能有一个,不允许其他类,创建该类的对象. 饿汉设计模式 示例代码如下: 1 public class Hunger ...
- mybatis 之定义拦截器 控制台SQL的打印
类型 先说明Mybatis中可以被拦截的类型具体有以下四种: 1.Executor:拦截执行器的方法.2.ParameterHandler:拦截参数的处理.3.ResultHandler:拦截结果集的 ...
- java线程基础知识整理
目录 线程基本概念 1.java实现线程 2.线程的生命周期 3.线程常用的方法 3.1.sleep() 3.2.interrupt方法 3.3.stop方法 4.线程调度 4.1.常见的线程调度模型 ...
- MasaFramework -- i18n (国际化)
概念 作为一个普通开发者, 我们负责的项目的使用群体大多数是本国的人民, 但不可避免的也有一些做外贸的业务或者给外企做的项目, 这个时候就要求我们的项目有服务全球客户的能力, 而一个支持国际化能力的框 ...
- SQLSERVER 居然也能调 C# 代码 ?
一:背景 1. 讲故事 前些天看到一个奇怪的 Function 函数,调用的是 C# 链接库中的一个 UserLogin 方法,参考代码如下: CREATE FUNCTION dbo.clr_User ...
- [机器学习] Yellowbrick使用笔记1-快速入门
Yellowbrick是一个机器学习可视化库,主要依赖于sklearn机器学习库,能够提供多种机器学习算法的可视化,主要包括特征可视化,分类可视化,回归可视化,回归可视化,聚类可视化,模型选择可视化, ...
- [R语言] ggplot2入门笔记2—通用教程ggplot2简介
文章目录 通用教程简介(Introduction To ggplot2) 2 ggplot2入门笔记2-通用教程ggplot2简介 1. 了解ggplot语法(Understanding the gg ...
- [IOI2016] shortcut
有显然的 \(O(n^3)\) 做法,可以获得 \(38pts\).(退火在洛谷上能跑 \(75pts\)) 答案具有单调性,考虑二分一个 \(M\) 并判断.列出 \(i\) 到 \(j\) 的距离 ...
- 《Effective C++》实现 章节
Item26:尽可能延后变量定义式的出现时间 Item27:尽量少做转型动作 关于这一点,专门开了一个新的总结: http://blog.csdn.net/m0_37316917/article/de ...
- 使用IIS配置代理,转发POST和GET访问,配置IIS接口转发失效问题处理
先说一下可能引发配置失败的原因:大概率是你的Application Request Routing没有配置好,或者你的正则表达没有搞好,往下看步骤自己对照哇~ 1.确保服务器已经安装IIS 2.下载U ...