《CNN Image Retrieval in PyTorch: Training and evaluati-ng CNNs for Image Retrieval in PyTorch》代码思路解读
这是一个基于微调卷积神经网络的图像检索的代码实现,这里我就基于代码做一个实现思路的个人解读,如果有不对的地方或者不够详细的地方,欢迎大家指出。
代码的GitHub地址:filipradenovic/cnnimageretrieval-pytorch (Commit c340540)
相关论文地址:
Fine-tuning CNN Image Retrieval with No Human Annotation, Radenović F., Tolias G., Chum O., TPAMI 2018 [arXiv]
CNN Image Retrieval Learns from BoW: Unsupervised Fine-Tuning with Hard Examples, Radenović F., Tolias G., Chum O., ECCV 2016 [arXiv]
写在前面
我是在2020年的4月份,为了我的本科毕设,研读学习这份代码。但到现在才有空闲来把我学习的成果整理到博客上,时间上是有些延迟的。作者在2020年12月份又更新了代码,发布了新的版本(V1.2),添加了亿点点细节。但这不影响我发表这篇博客,因为这个代码的整体思路还是不变的,这也是这篇博客的重点。如果这篇博客涉及到了代码的细节,那就是基于它V1.1的版本,更准确地说,是基于Commit c340540的版本。
关于我为什么是基于代码的解读,而不是基于论文的解读。原因是我做本科的毕设时,阅读论文的水平有限,把握不住实现的思路与过程。求助导师时,他说:可以阅读代码,阅读代码比阅读论文更容易。在此,也感谢导师的指导。
再次说明:对于数据集的一些更精妙的设计,以及对图像的白化处理等,这里由于水平的限制,都避开不谈。想要了解的同学,可以深入阅读论文,这篇博客并没有对这些内容进行解读。这篇博客只是梳理了非常基本的代码思路。
温馨提示:这篇博客有点长,点击右下角魔法阵上的【显示目录】,可以更方便导航。
前置知识
想要了解这个检索实现的思路,需要先了解一些相关的知识,有助于后续整体思路的把握。由于篇幅的关系,这个章节只会对这些知识做简要的介绍。如果已经了解,可跳过这一章节。这些前置知识有:孪生网络,Contrastive Loss,Triple Loss
孪生网络
简单来说,孪生网络可用来衡量两个输入的相似度。孪生网络的结构如下图所示,两个输入经过完全一样的神经网络,输出为各自的高维向量表征。得到输入的向量特征后,可以通过余弦相似度计算两个输入的相似度,也可以计算特征间的距离,如欧氏距离,通过loss计算,来评价模型的特征表示效果。
既然提到loss方程,接下来就要介绍两个ranking loss,用于学习相对距离,相关关系,常被用于孪生网络中。
图1 孪生网络结构
Contrastive Loss
中文称为【对比损失函数】,表达式如下所示。其中,d表示两个向量的距离,例如一般是欧氏距离;y表示两个输入是否相似,如果相似则为1,如果不相似为0;margin是设定好的阈值,当两个样本的向量距离超过一定值,也就是margin,就表示这两个样本不相似了。
$ L=\frac{1}{2 N} \sum_{n=1}^{N} y d^{2}+(1-y) \max (\operatorname{margin}-d, 0)^{2}$
从式子上我们可以发现,如果两个输入相似(即y=1),则式中只剩下 $ d^{2} $。这符合我们的直观感受:如果两个输入相似,向量的距离越大,则损失越大。如果两个输入不相似(即y=0),则式中只剩下 $ \max (\operatorname{margin}-d, 0)^{2} $。这里应该理解为:当两个输入不相似时,若向量的距离大于margin,则损失为0;若向量的距离小于margin,且距离越小,损失越大。于是优化的方向为让相似样本的向量特征距离变小,让不相似样本的向量特征距离超过阈值。
这里有张经典的图,如下图所示,红色虚线为相似时的曲线,蓝色实线为不相似情况的曲线,横坐标为样本间的特征距离,横坐标上有个特殊的点是margin值,纵坐标是损失值。从图上我们也发现,相似情况下(即红色曲线)损失值随距离的增大而增大;不相似的情况下(即蓝色曲线)损失值随距离的缩小而减小,且距离大于等于margin时,损失为0。(这两个曲线圈起来,后面要考的)
图2 Contrastive Loss 损失与距离的关系图
Triple Loss
中文称为【三元损失函数】,顾名思义,计算一次loss要同时输入三元:锚点样本(anchor),正样本(positive),负样本(negative),分别用$a,p,n$表示,损失函数的表达式如下所示。其中,$r_{a}, r_{p}, r_{n}$分别表示锚点样本,正样本,负样本的高维向量表征。而$\mathrm{d}(r_{a}, r_{p}), \mathrm{d}(r_{a}, r_{n}) $表示$<a,p>$之间的距离和$<a,n>$之间的距离。同样的,这里的m也是margin,是设定好的阈值,表示希望的$\mathrm{d}(r_{a}, r_{n})$ 与 $\mathrm{d}(r_{a}, r_{p}) $的差距。
$ L\left(r_{a}, r_{p}, r_{n}\right)=\max \left(0, m+\mathrm{d}\left(r_{a}, r_{p}\right)-\mathrm{d}\left(r_{a}, r_{n}\right)\right) $
从上面的式子中我们可以发现,当$\mathrm{d}(r_{a}, r_{n}) $比$\mathrm{d}(r_{a}, r_{p})$大,损失就小;且差距越大,损失越小;直到差值大于margin,损失为0。相反,如果差距越小,损失越大;甚至$\mathrm{d}(r_{a}, r_{p})$比$\mathrm{d}(r_{a}, r_{n}) $还大时,损失值就很大了。
整体思路
训练过程
了解了前面介绍的前置知识之后,再看图像检索的训练过程,就会理解他的用意。
下图是作者在GitHub和论文中都展示的图片,很好地表示了其核心思想。这里我对图中的文字做了本土化,如果想看原本的表达,再次指路GitHub或者论文。示意图中的$ \overline{\mathbf{f}}( ) $表示图像的高维特征向量。该图中的上半部分描述了原始图像转化为高维向量特征的过程:图像经过卷积层(也即卷积神经网络,如ResNet等去掉最后一层【全连接层】),再经过池化层和$\ell_{2}$归一化操作(即向量单位化),最终形成一个图像的固定维度的向量表示。示意图的下半部分描述了训练时,使用对比损失函数的情况。示意图中的两条Loss-dist曲线就是图2中的拆分。
图3 图像检索训练示意图
检索过程
对于检索的过程,我自己画了如下图所示的示意图。检索过程如下:
- 图片池里的图片转换为列向量特征,多个列向量特征再拼在一起组成矩阵;
- 将查询对象转换为列向量特征,如果有多个查询对象同时查询,则将它们的列向量特征拼成矩阵。
- 将图片池的特征矩阵转置后与查询对象的向量特征(即计算余弦相似度)得到相似度的结果。这个结果中第i行,第j列元素表示的是第i个图片池中的图片与第j个查询对象的相似度
图4 图像检索过程示意图
欧氏距离与余弦相似度
看到这里,不知道大家有没有疑问:训练过程,loss方程用的是Contrastive Loss或者Triple Loss 本质都是让相似样本的距离更近,不相似样本的距离更远。这里的距离用的是欧氏距离。但实际检索时,不是用样本间的欧氏距离排名,而是用余弦相似度排名。诚然,余弦相似度计算更简单,只要矩阵乘法运算,不需要像欧氏距离一样计算平方。但是,从理论上来说,这样是可行的吗?难道向量间的欧氏距离越近,余弦相似度越高?如果没有这个疑问的小伙伴就跳过这part吧!
回到我们的疑问:难道向量间的欧氏距离越近,余弦相似度越高?这当然不是绝对的,我们可以很轻松举出反例。那难度作者错了吗?并没有,这个结论在一定条件下是可以成立的,那就是当向量的模长一定时,这个结论是成立的!而作者早在图3中,就保证了这关键的一步,是的,就是$\ell_{2}$归一化操作(即向量单位化)。这些向量特征都是单位向量!
下面是以上结论的一个证明:其实,向量间的欧氏距离和余弦相似度由余弦公式建立联系的,设两个向量分别为 $\boldsymbol{a}$, $\boldsymbol{b}$ ,则有以下的关系。其中,$ d_{<\boldsymbol{a}, \boldsymbol{b}>} $ 表示两个向量间的欧氏距离。由这个公式我们可以得知,两个向量模长一定时,欧氏距离越近,余弦相似度越高。
$ d_{<\boldsymbol{a}, \boldsymbol{b}>}^{2} = \left |\boldsymbol{a} \right |^{2} + \left |\boldsymbol{b} \right |^{2} - 2\boldsymbol{a}\cdot \boldsymbol{b} $
数据集介绍
作者使用了 retrieval-SfM-120k 作为训练集和验证集。使用Oxford5k,Paris6k,ROxford5k,RParis6k作为测试集。接下来对这几个数据集做个介绍,会涉及到具体的文件细节。
retrieval-SfM-120k
retrieval-SfM-120k是若干张建筑物的图片,它们已经分好了簇,相似图片在一个簇里,不同簇的图片即为不相似。还有若干对q-p图片,q表示查询,p表示相似的图片。
文档结构
retrieval-SfM-120k 下载解压后目录结构包含ims文件夹,retrieval-SfM-120k.pkl 和retrieval-SfM-120k-whiten.pkl。其中 ims 文件夹 存放图片 。 retrieval-SfM-120k.pkl 存 放 图 片相关信息。retrieval-SfM-120k-whiten.pkl 的内容还不了解,这里不做解释,不影响代码整体的理解。
ims 下还有三级目录。其中图片的文件名(代码中称为 cid)的逆序的前六位决定了图片的路径,如某图的 cid 逆序前六位为 123456,则其路径为./ims/12/34/56。
retrieval-SfM-120k.pkl 包含了这个 database 的相关信息。其字典结构图如图 5 所示。其中按mode分为train和val。train和val又是两个字典,分别包含cids,cluster,qidxs, pidxs这四个关键字,这四个key对应的value都是列表。其中 cids 是所有图片的文件名列表;cluster 是由 cids 中对应图片的 clusterID 组成的列表;qidxs 和 pidxs 是一一对应的,组成一对对 q-p对,其中 q 表示查询对象在 cids 中的下标,p 表示与查询对象匹配的图像在 cids的下标。而图中的数字为该列表的元素个数。
图5 retrieval-SfM-120k.pkl字典结构
Oxford5k, Paris6k,ROxford5k,RParis6k
Oxford5k 由 5062 张已被人工标注的图片组成,其中有 55 张是查询对象,是11 个地标的五种不同拍摄条件下的图片。类似地,Paris6k 数据集由 6412 个图像和 55 个查询组成。而ROxford5k和RParis6k是Radenovi 等人,重新整理了Oxford5k,Paris6k这两个数据集而成的。所做的改动如下:每个数据集新增了 15 个查询;修改标注错误和数据集大小;还根据答案集的不同,设置了挑战级别:Easy,Medium,Hard。
文档结构
代码中会自动下载对应的测试集,这四个数据集下载后的文档结构比较像,以Oxford5k为例,有一个名为jpg的文件夹和一个gnd_oxford5k.pkl的文件。
其中,jpg的文件夹里面存放的就是图片,gnd_oxford5k.pkl文件存放的是图片的检索信息。
接下来具体介绍各个gnd_xxx.pkl文件,其中xxx是对应的数据集名称,如oxford5k,paris6k等。这些gnd_xxx.pkl文件存放一个了dict,其字典结构如图6所示。其中,imlist 和 qimlist 都是列表,保存着每张图片或查询对象的文件名,下面的数字是列表长度。而 gnd 也是列表,和 qimlist 列表是一一对应的,保存了对应的查询图像的检索信息。而每个gnd元素都是字典,都包含若干项关键字,包含的关键字如图6所示。其中 bbx 内有 4 个元素,为int类型:x1,y1,x2,y2, 表示了查询图片的具体查询区域。而ok,junk是与查询对象匹配与不匹配的图片在 imlist 的索引列表,easy,hard,junk也是针对查询对象的匹配程度分出来的图片在 imlist 的索引列表,根据这三个列表,可以把检索分为三个难度:Easy,Medium,Hard。
图6 gnd_xxx.pkl 字典结构
正负集划分
如果是 Oxford5k 和 Paris6k,正类就是ok列表,负类就是junk列表。而对于 ROxford5k 和 RParis6k,会分成三个难度:Easy(E),Medium(M),Hard(H),这三个难度下的正负类划分不太一样,具体见表 1。
E | M | H | |
正类 | easy | easy, hard | hard |
负类 | hard, junk | junk | junk, easy |
具体实现细节
这一章节会涉及到具体的实现细节。这一章节提到的过程,都是默认参数下的过程,减少了许多可选操作的说明,如白化操作等。
- 模型 AlexNet,Vgg16,ResNet101等经典模型去掉全连接层作为卷积层,再加上一层池化操作和$\ell_{2}$正则化操作。其中池化可以是最大值池化,平均值池化和广义平均值池化(数学上,广义平均值也就是p次幂平均)。
- 数据库的选择 训练集:retrieval-SfM-120k['train'],验证集:retrieval-SfM-120k['val'],测试集:Oxford5k, Paris6k, ROxford5k, RParis6k。
- 训练时模型的输入 训练集中的图片通过模型变成特征向量。从中选取qsize(q-p对的个数)个元组。每个元组共有(1+1+nnum)个特征向量,分别是查询对象q,正类p和nnum个负类n1,n2....查询和正类是由q-p对直接给出。负类是q由当前模型的在图片池中的查询结果,按照查询顺序从上到下依次选取nnum个与q在不同簇的图片,且这nnum个图片也在不同的簇中。(注:这些元组在不同的epoch就会更新一次,因为模型更新了。)这里呼应了论文标题中的No Human Annotation(不需要人工标注)。
- 训练时模型的输出 每个元组经过模型的向量特征组成的矩阵(矩阵维度:特征维度*(1+1+nnum) )
- 训练的标签 [-1,1,0,0,...,0]。-1代表查询对象,1表示匹配,0表示不匹配。与输入的元组的每个特征向量一一对应。
- 损失函数 如果是Contrastive Loss,每个元组的loss是nnum个负类与查询对象的Contrastive Loss 和 nnum个相同的正类与查询对象的Contrastive Loss 的和;如果是Triple Loss,每个元组的loss是nnum个(查询对象,正类,其中一个负类)的Triple Loss 的和。他们的d都是用向量的欧氏距离定义的。
- 训练的优化 采用Adam算法优化,学习率随着epoch指数衰减,公式为: $ l r=l r_{0} \times \gamma^{\text {epoch }} $
- 测试时模型的输入 测试集中图库的图片和查询对象的图片
- 测试时模型的输出 查询对象的特征矩阵(所有查询对象的特征向量组成的矩阵)和图库图片特征矩阵(图库图片所有的特征向量组成的矩阵)
- 测试的检索排名 图库图片特征矩阵与查询对象特征矩阵的点乘,得到的是scores矩阵(维度:图库图片数量* 查询数量),其中第i行,第j列表示图片池中的第i个图片与第j个查询对象的相似度得分。ranks是scores的按列排序的索引值,即得分高的图片的索引排在前面,是最终的检索结果。
- 检索的评价指标 Oxford5k 和 Paris6k 的检索结果的指标是 mAP(mean Average Precision),AP 是单个查询结果的平均准确率,mAP 是所有查询结果 AP 的平均。ROxford5k 和 RPairs6k 的检索指标比较丰富。除了mAP 用于评价整体的检索质量外,新增了 mP@k,是结果列表中 top-k 检索结果的准确率指标,反映了搜索引擎的质量。匹配的图片排的越前面得分会越高,不匹配的图片越排在匹配的后面得分会越高。
文件目录
代码的文件结构及其说明如下所示。
.
│ LICENSE
│ README.md
│
└─cirtorch
│ __init__.py
│
├─datasets 数据集加载和处理
│ datahelpers.py 图片处理方法
│ genericdataset.py 定义通过文件名列表加载图片的方法
│ testdataset.py 定义生成测试集的方法
│ traindataset.py 定义生成训练集和验证集的方法
│ __init__.py
│
├─examples 包含所有可运行的文件
│ test.py 测试模型(相比e2e版本,增加许多可选操作)
│ test_e2e.py 端到端测试模型
│ train.py 训练模型
│ __init__.py
│
├─layers 定义神经网络里的层操作方法
│ functional.py 定义以下三个文件要用的函数
│ loss.py 定义损失函数
│ normalization.py 定义正则化方法
│ pooling.py 定义池化方法
│ __init__.py
│
├─networks 定义所用的模型
│ imageretrievalnet.py 定义模型,初始化模型,定义通过模型生成特征的方法
│ __init__.py
│
└─utils 包含所有工具文件
download.py 下载各个数据集
download_win.py 未知,没怎么用到
evaluate.py 定义计算检索评价指标的方法
general.py 定义通用的工具方法
whiten.py 定义白化方法
__init__.py
后记
关于训练和测试的代码运行,GitHub上都有相关的指导,这里就不多赘述。这篇文章旨在解读代码,帮助初入门的同学快速掌握基本思想。对于数据集的一些更精妙的设计,以及对图像的白化处理等,这里由于水平的限制,都避开不谈,想要了解的同学,可以深入阅读论文,这篇博客并没有对这些内容进行解读。如果文中有错误的地方,欢迎大家私信或评论。
《CNN Image Retrieval in PyTorch: Training and evaluati-ng CNNs for Image Retrieval in PyTorch》代码思路解读的更多相关文章
- 【PyTorch深度学习60分钟快速入门 】Part1:PyTorch是什么?
0x00 PyTorch是什么? PyTorch是一个基于Python的科学计算工具包,它主要面向两种场景: 用于替代NumPy,可以使用GPU的计算力 一种深度学习研究平台,可以提供最大的灵活性 ...
- 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)
文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...
- (原)CNN中的卷积、1x1卷积及在pytorch中的验证
转载请注明处处: http://www.cnblogs.com/darkknightzh/p/9017854.html 参考网址: https://pytorch.org/docs/stable/nn ...
- ubuntu之路——day18 用pytorch完成CNN
本次作业:Andrew Ng的CNN的搭建卷积神经网络模型以及应用(1&2)作业目录参考这位博主的整理:https://blog.csdn.net/u013733326/article/det ...
- (转)Awesome PyTorch List
Awesome-Pytorch-list 2018-08-10 09:25:16 This blog is copied from: https://github.com/Epsilon-Lee/Aw ...
- 使用pytorch构建神经网络的流程以及一些问题
使用PyTorch构建神经网络十分的简单,下面是我总结的PyTorch构建神经网络的一般过程以及我在学习当中遇到的一些问题,期望对你有所帮助. PyTorch构建神经网络的一般过程 下面的程序是PyT ...
- 吐血整理:PyTorch项目代码与资源列表 | 资源下载
http://www.sohu.com/a/164171974_741733 本文收集了大量基于 PyTorch 实现的代码链接,其中有适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文 ...
- Note | PyTorch官方教程学习笔记
目录 1. 快速入门PYTORCH 1.1. 什么是PyTorch 1.1.1. 基础概念 1.1.2. 与NumPy之间的桥梁 1.2. Autograd: Automatic Differenti ...
- 转 Pytorch 教学资料
本文收集了大量PyTorch项目(备查) 转自:https://blog.csdn.net/fuckliuwenl/article/details/80554182 目录: 入门系列教程 入门实例 图 ...
随机推荐
- Windows bat批处理删除指定N天前的文件
1:新建批处理文件:del_old_file.bat,更改系统时间为7天前,在c盘sql back 目录下新建测试文件,再将系统时间改为正确时间 2:编辑内容: rem 删除C:\sql back目录 ...
- 21.Quick QML-FileDialog、FolderDialog对话框
1.FileDialog介绍 Qt Quick中的FileDialog文件对话框支持的平台有: 笔者使用的是Qt 5.8以上的版本,模块是import Qt.labs.platform 1.1. 它的 ...
- 『政善治』Postman工具 — 8、Postman中Pre-request Script的使用
目录 1.Pre-request Script介绍 2.常用SNIPPETS(片段)说明 (1)获取变量脚本: (2)设置变量脚本: (3)清空变量脚本: (4)Send a request代码片段 ...
- MFC对话框不能使用菜单更新宏ON_UPDATE_COMMAND_UI
菜单更新宏的原理 更新处理宏的工作原理是基于框架窗口类的.MFC中对话框菜单更新宏的原理是:当我们使用从CFrameWnd框架窗口类中派生的类创建窗口时,当我们单击菜单且菜单还未弹出前会产生WM_IN ...
- char值不能直接用作数组下标
#include <stdio.h> //用 char 的值作为数组下标(例如,统计字符串中每个字符出现的次数),要考虑到 //char 可能是负数.有的人考虑到了,先强制转型为 unsi ...
- Dart 2.13 版现已发布
作者 / Kevin Moore & Michael Thomsen Dart 2.13 版现已发布,其中新增了类型别名功能,这是目前用户呼声第二高的语言功能.Dart 2.13 还改进了 D ...
- 数据流分析软件SQLFlow的高阶模式Job任务介绍
SQLFlow是一个可视化的在线处理SQL对象依赖关系的工具,只需要上传你的SQL脚本,它可以自动分析SQL里的数据对象,包括database.schema.table.view.column.pro ...
- [刷题] 144 Binary Tree Preorder Traversal
要求 二叉树的前序遍历 实现 递归 栈模拟 定义结构体 Command 模拟指令,字符串s描述命令,树节点node为指令作用的节点 定义栈 Stack 存储命令 1 #include ...
- 在写脚本时,在一开始(Shebang 之后)就加上这一句,或者它的缩略版: set -xeuo pipefail
编写可靠 bash 脚本的一些技巧 腾讯技术工程 已认证的官方帐号 1,254 人赞同了该文章 写过很多 bash 脚本的人都知道,bash 的坑不是一般的多. 其实 bash 本身并不是一个 ...
- Zookeeper——Docker下安装部署
单节点安装 一. 环境说明 docker: 18.09.9-ce zookeeper: 3.5.6 二. 拉取 zookeeper 镜像 拉取镜像 docker pull zookeeper 默认是摘 ...