Deep Metric Learning via Lifted Structured Feature Embedding

CVPR 2016

  摘要:本文提出一种距离度量的方法,充分的发挥 training batches 的优势,by lifting the vector of pairwise distances within the batch to the matrix of pairwise distances. 刚开始看这个摘要,有点懵逼,不怕,后面会知道这段英文是啥意思的。

  

  引言部分开头讲了距离相似性度量的重要性,并且应用广泛。这里提到了三元组损失函数 (triplet loss),就是讲在训练的过程当中,尽可能的拉近两个相同物体之间的距离,而拉远不同物体之间的距离;这种做法会比普通的训练方法得到更好的效果。但是,文章中提到,现有的三元组方法却无法充分利用 minibatch SGD training 的 training batches 的优势。现有的方法首先随机的采样图像对或者三元组,构建训练 batches, 计算每一个 pairs or triplets 的损失。本文提出一种方法,称为:lifts,将 the vector of pairwise distances 转换成 the matrix of pairwise distance. 然后在 lifts problem 上设计了一个新的结构损失目标。结果表明,在 GoogleLeNet network 上取得了比其他方法都要好的结果。

  然后作者简单的回顾了一下关于判别性训练网络(discriminatively training networks)来学习 semantic embedding。大致结构预览图如下所示:

  首先是: Contrastive embedding. 

  这种方法在 paired data ${(x_i, x_j, y_{ij})}$上进行训练。这种 contrastive training 最小化具有相同 label 类别的样本之间的距离,然后对不同label的样本,但是其距离小于 $\alpha$ 的 negative pair 给予惩罚。代价函数的定义为:

  其中,m 代表batch中图像的个数,f(*)是网路输出的特征,即原文中表达的:the feature embedding output from the network. $D_{i, j}$ 是两个样本特征之间欧式距离的度量。标签 $y_{i, j} \in {0, 1}$表明是否样本对来自同一个类别。$[*]_+$ 操作是 the hinge function max(0, *)。

  第二个是:Triplet embedding

  这个就是著名的三元组损失函数了,即:找一个 anchor,然后找一个正样本,一个负样本。训练的目的就是:鼓励网络找到一个 embedding 使得 xa and xn 之间的距离大于 xa and xp 加上一个 margin $\alpha$ 的和。损失函数定义为:

  其中,D仍然表示样本之间特征的距离。

  然后就是本文提出的一种度量方法了:

  Deep metric learning via lifted structured feature embedding.

   我们基于训练集合的正负样本,定义了一个结构化的损失函数:  

  其中,P 是正样本的集合,N 是负样本的集合。这个函数提出了两个计算上的挑战:

  1. 非平滑(non-smooth)

  2. 评价和计算其子梯度需要最小化所有样本对若干次。

  我们以两种方式解决了上述挑战:

  首先,我们优化上述函数的一个平滑上界;

  第二,对于大数据常用的方法类似,我们采用随机的方法。

  然而,前人的工作都是用SGD的方法,随机的均匀的选择 pairs or triplets。我们的方法从这之中得到了借鉴:

    (1). it biases the sample towards including "difficult" pairs, just like a subgradient of $J_{i,j}$ would use the close negative pairs;

  (2). 一次采样就充分的利用了一个 mini-batch的全部信息,而不仅仅是两个pair之间的信息。

  为了充分的利用这个 batch,一个关键的 idea 是增强 mini-batch 的优化以利用所有的pairs。

  需要注意的是:随机采样的样本对之间的 negative edges 携带了非常有限的信息。

  

  所以,我们的方法改为并非完全随机,而是引入了重要性采样的元素。我们随机的采样了一些 positive pairs,然后添加了一些他们的 difficult neighbors 来训练 mini-batch. 这个增强增加了子梯度会用到的相关信息。下图展示了一个 positive pair 在一个 batch 中的搜索过程,即:在一个 positive pair 的图像中,我们找到其 close(hard)negative images。  

  注意到我们的方法可以从两端开始搜索,而三元组则仅仅只能和定义好的结构上的元素进行搜索。

  

   此外,搜索 single hardest negative with nested max function 实际上会导致网络收敛到一个 bad local optimum. 所以我们采用了如下的 smooth upper bound,所以 我们的损失函数定义为:  

  其中,P是batch中 positive pairs 集合,N 是negative pairs 的集合。后向传播梯度可以如算法1所示的那样,对应距离的梯度为:

  

    其中的 1[*] 是指示函数,如果括号内的判断为真,那么输出为1,否则就是0.

  本文的算法流程图,如下所示:

  


  结果展示:

  


    文章总结

  可以看出,本文是在三元组损失函数基础上的一个改进。并非仅仅考虑预先定义好的样本之间的差异性,而是考虑到一个 batches 内部 所有的样本之间的差异。在这个过程中,文章中引入了类似 hard negative mining 的思想,考虑到正负样本之间的难易程度。并且为了避免网络的训练陷入到 局部最优的bug中去,引入了损失函数的上界来缓解这个问题。

  一个看似不大的改动,却可以发到CVPR,也从某个角度说明了这个方法的价值。

  难道,三元组损失函数就这样被这个算法击败了? 自己当初看到三元组损失函数的时候,为什么就没有忘这个方向去思考呢???

  还有一个疑问是:为什么这种方法的操作,称为:lifted structured feature embedding ?

  难道说,是因为这个左右移动的搜索 hard negative samples 的过程类似于电梯(lift)?那 feature embedding 怎么理解呢? embedding 是映射,难道是:特征映射么??

  

论文笔记之: Deep Metric Learning via Lifted Structured Feature Embedding的更多相关文章

  1. 论文笔记:Deep Residual Learning

    之前提到,深度神经网络在训练中容易遇到梯度消失/爆炸的问题,这个问题产生的根源详见之前的读书笔记.在 Batch Normalization 中,我们将输入数据由激活函数的收敛区调整到梯度较大的区域, ...

  2. 读论文系列:Deep transfer learning person re-identification

    读论文系列:Deep transfer learning person re-identification arxiv 2016 by Mengyue Geng, Yaowei Wang, Tao X ...

  3. 论文笔记——A Deep Neural Network Compression Pipeline: Pruning, Quantization, Huffman Encoding

    论文<A Deep Neural Network Compression Pipeline: Pruning, Quantization, Huffman Encoding> Prunin ...

  4. 【论文阅读】Deep Mutual Learning

    文章:Deep Mutual Learning 出自CVPR2017(18年最佳学生论文) 文章链接:https://arxiv.org/abs/1706.00384 代码链接:https://git ...

  5. 论文解读《Deep Resdual Learning for Image Recognition》

    总的来说这篇论文提出了ResNet架构,让训练非常深的神经网络(NN)成为了可能. 什么是残差? "残差在数理统计中是指实际观察值与估计值(拟合值)之间的差."如果回归模型正确的话 ...

  6. Person Re-identification 系列论文笔记(二):A Discriminatively Learned CNN Embedding for Person Re-identification

    A Discriminatively Learned CNN Embedding for Person Re-identification Zheng Z, Zheng L, Yang Y. A Di ...

  7. 论文笔记:Deep feature learning with relative distance comparison for person re-identification

    这篇论文是要解决 person re-identification 的问题.所谓 person re-identification,指的是在不同的场景下识别同一个人(如下图所示).这里的难点是,由于不 ...

  8. 论文笔记:Deep Attentive Tracking via Reciprocative Learning

    Deep Attentive Tracking via Reciprocative Learning NIPS18_tracking Type:Tracking-By-Detection 本篇论文地主 ...

  9. 论文笔记 — L2-Net: Deep Learning of Discriminative Patch Descriptor in Euclidean Space

    论文: 本文主要贡献: 1.提出了一种新的采样策略,使网络在少数的epoch迭代中,接触百万量级的训练样本: 2.基于局部图像块匹配问题,强调度量描述子的相对距离: 3.在中间特征图上加入额外的监督: ...

随机推荐

  1. ElasticSearch学习问题记录——Invalid shift value in prefixCoded bytes (is encoded value really an INT?)

    最近在做一个电商项目,其中商品搜索中出现一个奇怪的现象,根据某个字段排序的时候会出现商品数量减少的情况.按照一般路要么查不出来,要么正常显示,为什么增加了按照销量排序就会出现查询结果减少的情况. 查了 ...

  2. poj2778DNA Sequence(AC自动机+矩阵乘法)

    链接 看此题前先看一下matrix67大神写的关于十个矩阵的题目中的一个,如下: 经典题目8 给定一个有向图,问从A点恰好走k步(允许重复经过边)到达B点的方案数mod p的值    把给定的图转为邻 ...

  3. Oracle的不完全恢复

    一.不完全恢复特性 1.不完全恢复 不完全恢复仅仅是将数据恢复到某一个特定的时间点或特定的SCN,而不是当前时间点.不完全恢复会影响整个数据库,需要在MOUNT状  态下进行.在不完全恢复成功之后,通 ...

  4. [http session]

    原文链接:http://lavasoft.blog.51cto.com/62575/275589/ 1.Session创建的时间是: 一个常见的误解是以为session在有客户端访问时就被创建,然而事 ...

  5. SPSS数据分析—描述性统计分析

    描述性统计分析是针对数据本身而言,用统计学指标描述其特征的分析方法,这种描述看似简单,实际上却是很多高级分析的基础工作,很多高级分析方法对于数据都有一定的假设和适用条件,这些都可以通过描述性统计分析加 ...

  6. html5-表单

    例子: text,number,email 的输入框 <!-- required:必填项 --> <!-- autofocus:获得焦点 --> <!-- placeho ...

  7. iOS开发UI篇—Date Picker和UITool Bar控件简单介绍

    iOS开发UI篇—Date Picker和UITool Bar控件简单介绍 一.Date Picker控件 1.简单介绍: Date Picker显示时间的控件 有默认宽高,不用设置数据源和代理 如何 ...

  8. iOS开发Swift篇—(二)变量和常量

    iOS开发Swift篇—(二)变量和常量 一.语言的性能 (1)根据WWDC的展示 在进行复杂对象排序时Objective-C的性能是Python的2.8倍,Swift的性能是Python的3.9倍 ...

  9. opencv基本的数据结构(转)

    DataType : 将C++数据类型转换为对应的opencv数据类型 enum { CV_8U=0, CV_8S=1, CV_16U=2, CV_16S=3, CV_32S=4, CV_32F=5, ...

  10. 【转】nginx+tomcat+memcached (msm)实现 session同步复制

    出现session不同步时,请放到content.xml中,实际验证有效: tomcat + memcached + nginx 实现session共享 这里重点强调如何实现linux服务器上 服务器 ...