最近在多个关键词(小数据集,无监督半监督,图像分割,SOTA模型)的范畴内,都看到了这样的一个概念,孪生网络,所以今天有空大概翻看了一下相关的经典论文和博文,之后做了一个简单的案例来强化理解。如果需要交流的话欢迎联系我,WX:cyx645016617

所以这个孪生网络入门,我想着分成上下两篇,上篇也就是这一篇讲解模型理论、基础知识和孪生网络独特的损失函数;下篇讲解一下如何用代码来复线一个简单的孪生网络。

1 名字的由来

孪生网络的别名就会死Siamese Net,而Siam是古代泰国的称呼,所以Siamese其实是“泰国人”的古代的称呼。 为什么Siamese现在在英文中是“孪生”“连体”的意思呢?这源自一个典故:

十九世纪泰国出生了一对连体婴儿,当时的医学技术无法使两人分离出来,于是两人顽强地生活了一生,1829年被英国商人发现,进入马戏团,在全世界各地表演,1839年他们访问美国北卡罗莱那州后来成为“玲玲马戏团” 的台柱,最后成为美国公民。1843年4月13日跟英国一对姐妹结婚,恩生了10个小孩,昌生了12个,姐妹吵架时,兄弟就要轮流到每个老婆家住三天。1874年恩因肺病去世,另一位不久也去世,两人均于63岁离开人间。两人的肝至今仍保存在费城的马特博物馆内。从此之后“暹罗双胞胎”(Siamese twins)就成了连体人的代名词,也因为这对双胞胎让全世界都重视到这项特殊疾病。

2 模型结构

这个图有这几个点来理解:

  • 其中的Network1和Network2按照专业的话来说就是共享权制,说白了这两个网络其实就是一个网络,在代码中就构建一个网络就行了;
  • 一般的任务,每一个样本经过模型得到一个模型的pred,然后这个pred和ground truth进行损失函数的计算,然后得到梯度;这个孪生网络则改变了这种结构,假设是图片分类的任务,把图片A输入到模型中得到了一个输出pred1,然后我再把图片B输入到模型中,得到了另外一个输出pred2,然后我这个损失函数是从pred1和pred2之间计算出来的。 就是一般情况下,模型运行一次,给出一个loss,但是在siamese net中,模型要运行两次才能得到一个loss。
  • 我个人感觉,一般的任务像是衡量一种绝对的距离,样本到标签的一个距离;但是孪生网络衡量的是样本到样本之间的一个距离。

2.1 孪生网络的用途

Siamese net衡量的是两个输入的关系,也就是两个样本相似还是不相似。

有这样的一个任务,在NIPS上,在1993年发表了文章《Signature Verification using a ‘Siamese’ Time Delay Neural Network》用于美国支票上的签名验证,检验支票上的签名和银行预留的签名是否一致。当时论文中就已经用卷积网络来做验证了...当时我还没出生。

之后,2010年Hinton在ICML上发表了《Rectified Linear Units Improve Restricted Boltzmann Machines》,用来做人脸验证,效果很好。输入就是两个人脸,输出就是same or different

可想而知,孪生网络可以做分类任务。在我看来,孪生网络不是一种网络结构,不是resnet那种的网络结构,而是一种网络的框架,我可以把resnet当成孪生网络的主干网络这样的

既然孪生网络的backbone(我们暂且这样叫,应该可以理解的把)可以是CNN,那么也自然可以是LSTM,这样可以实现词汇的语义的相似度分析

之前Kaggle上有一个question pair的比赛,衡量两个问题是否提问的是同一个问题这样的比赛,TOP1的方案就是这个孪生网络的结构Siamese net。

后来好像还有基于Siamese网络的视觉跟踪算法,这个我还没有了解,以后有机会的话我看一看这个论文。《Fully-convolutional siamese networks for object tracking》。先挖一个坑。

2.2 伪孪生网络

问题来了,孪生网络中看似两个网络,实则共享权制为一个网络,假设我们真的给他弄两个网络,那样不就可以一个是LSTM,一个CNN实现不同模态的相似度比较了?

没错,这个叫做pseudo-siamese network 伪孪生网络。一个输入是文字,一个输入是图片,判断文字描述是否是图片内容;一个是短标题,一个是长文章,判断文章内容是否是标题。(高中语文作文常年跑题选手的救星,以后给老师说这个算法说我的文章没有跑题,您要不再看看?老师会打死我吗)

不过本文和下一篇的代码都是以siamese network为核心,backbone也以CNN卷积网络和图像展开。

2.3 三胞胎

既然有了二胞胎的网络,当然也有三胞胎,叫做Triplet network《Deep metric learning using Triplet network》。据说效果已经好过Siamese network了,不知道有没有四胞胎和五胞胎。

3 损失函数

分类任务常规使用softmax加上交叉熵,但是有人提出了,这种方法训练的模型,在“类间”区分性上表现的并不好,使用对抗样本攻击就立刻不行了。后续有空讲解一下对抗样本攻击,再挖个坑。 简单的说就是,假设是人脸识别,那么每个人就是一个类别,那么你让一个模型做一个几千分类的任务,每一个类别的数据又很少的情况下,想想也会感觉到这个训练的难度。

针对这样的问题,孪生网络有两个损失函数比较近经典:

  • Contrastive Loss
  • Triplte Loss

3.1 Contrastive Loss

  • 提出论文:《Dimensionality Reduction by Learning an Invariant Mapping》

    现在我们已知:
  • 图片1 经过模型 得到pred1
  • 图片2 经过模型 得到pred2
  • pred1和pred2计算得到loss

论文中给出了这样的一个计算公式:

首先呢,这个经过模型得到的pred1和pred2都是向量,过程相当于图片经过CNN提取特征,然后得到了一个隐含向量,是一个Encoder的感觉。

然后计算这两个向量的欧氏距离,这个距离(如果模型训练的正确的话),就可以反应两个输入图像的相关性。我们每次输入两个图片,我们需要事先确定这两个图像是一类的,还是不同的,这个类似一个标签,也就是上图公式中的Y。如果是一类的,那么Y为0,如果不是,Y=1

类似于二值交叉熵损失函数,我们需要注意的是:

  • Y=0的时候,损失为:\((1-Y)L_S(D_W^i)\)
  • Y=1的时候,损失为:\(YL_D(D_W^i)\).
  • 其中论文中\(L_D,L_S\)是常数,论文中默认取0.5
  • i是一个次方的含义,论文中和常用的contrastive loss中,都是默认i=2,也就是欧氏距离的平方。
  • 对于类别是1(different类别的),我们自然是希望pred1和pred2的欧氏距离越大越好。那么这个大到什么程度是个头呢?损失函数是往小的方向移动,那么需要做什么呢?增加一个margin,当作最大的距离。如果pred1和pred2的距离大于margin,那么就认为这两个样本距离足够大,就当其的损失为0。所以写的方法就是:\(max(margin-distance,0)\).
  • 上图中的W我理解为神经网络的weight,然后\(\vec X_1\),表示要输入的原图片。

所以损失函数就变成这个样子:

总结一下,这里面需要注意的应该就是对于different的两个图片,需要设置一个margin,然后小于margin的计算损失,大于margin的损失为0.

3.2 Contrastive Loss pytorch

# Custom Contrastive Loss
class ContrastiveLoss(torch.nn.Module):
"""
Contrastive loss function.
Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
""" def __init__(self, margin=2.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin def forward(self, output1, output2, label):
euclidean_distance = F.pairwise_distance(output1, output2)
loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) + # calmp夹断用法
(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) return loss_contrastive

其中唯一需要谈一下的可能就是torch.nn.functional.pariwise_distance,

这个就是计算对应元素的欧氏距离,举个例子:

import torch
import torch.nn.functional as F
a = torch.Tensor([[1,2],[3,4]])
b = torch.Tensor([[10,20],[30,40]])
F.pairwise_distance(a,b)

输出为:

然后看一下这个数字是不是欧氏距离:

没问题的啊

3.3 Triplte Loss

  • 提出论文:《FaceNet: A Unified Embedding for Face Recognition and Clustering》

这个论文提出了FactNet,然后使用了Triplte Loss。Triplet Loss即三元组损失,我们详细来介绍一下。

  • Triplet Loss定义:最小化锚点和具有相同身份的正样本之间的距离,最小化锚点和具有不同身份的负样本之间的距离。这个其实应该是三胞胎网络的损失函数,同时输入三个样本,一个图片,然后一个same类别的图片和一个different图片。
  • Triplet Loss的目标:Triplet Loss的目标是使得相同标签的特征在空间位置上尽量靠近,同时不同标签的特征在空间位置上尽量远离,同时为了不让样本的特征聚合到一个非常小的空间中要求对于同一类的两个正例和一个负例,负例应该比正例的距离至少远margin。如下图所示:

这个的话我们要如何构建损失函数呢?已知我们想要的:

  • 让anchor和positive得到的向量的欧氏距离越小越好;
  • 让anchor和negative得到的向量的欧氏距离越大越好;

所以期望下面这个公式成立:

  • 简单的说就是anchor和positive的距离要比anchor和negative的距离小,而且这个差距要至少要大于\(\alpha\)。个人的思考是,这里的T,是三元组的集合。对于一个数据集,往往可以构建出非常多的三元组,因此我个人感觉这种任务一般用在类别多,数据量较少的任务中,不然三元组数量爆炸了

3.4 Triplte Loss keras

这里有一个keras的triplte loss的代码

def triplet_loss(y_true, y_pred):
"""
Triplet Loss的损失函数
""" anc, pos, neg = y_pred[:, 0:128], y_pred[:, 128:256], y_pred[:, 256:] # 欧式距离
pos_dist = K.sum(K.square(anc - pos), axis=-1, keepdims=True)
neg_dist = K.sum(K.square(anc - neg), axis=-1, keepdims=True)
basic_loss = pos_dist - neg_dist + TripletModel.MARGIN loss = K.maximum(basic_loss, 0.0) print "[INFO] model - triplet_loss shape: %s" % str(loss.shape)
return loss

参考文献:

[1] Momentum Contrast for Unsupervised Visual Representation Learning, 2019, Kaiming He Haoqi Fan Yuxin Wu Saining Xie Ross Girshick

[2] Dimensionality Reduction by Learning an Invariant Mapping, 2006, Raia Hadsell, Sumit Chopra, Yann LeCun

孪生网络入门(上) Siamese Net及其损失函数的更多相关文章

  1. 孪生网络入门(下) Siamese Net分类服装MNIST数据集(pytorch)

    主题列表:juejin, github, smartblue, cyanosis, channing-cyan, fancy, hydrogen, condensed-night-purple, gr ...

  2. Pytorch 入门之Siamese网络

    首次体验Pytorch,本文参考于:github and PyTorch 中文网人脸相似度对比 本文主要熟悉Pytorch大致流程,修改了读取数据部分.没有采用原作者的ImageFolder方法:   ...

  3. 孪生网络(Siamese Network)在句子语义相似度计算中的应用

    1,概述 在NLP中孪生网络基本是用来计算句子间的语义相似度的.其结构如下 在计算句子语义相似度的时候,都是以句子对的形式输入到网络中,孪生网络就是定义两个网络结构分别来表征句子对中的句子,然后通过曼 ...

  4. 关于 Android 5.0 原生系统网络图标上的感叹号问题解决方法

    解决方案 adb shell settings put global captive_portal_server g.cn 参考 关于 android 5.0 网络图标上的感叹号及其解决办法

  5. Android开发之从网络URL上下载JSON数据

    网络下载拉取数据中,json数据是一种格式化的xml数据,非常轻量方便,效率高,体验好等优点,下面就android中如何从给定的url下载json数据给予解析: 主要使用http请求方法,并用到Htt ...

  6. Java面向对象 网络编程 上

     Java面向对象 网络编程 上 知识概要:                     (1)网络模型 (2)网络通讯要素 (3)UDP TCP 概念 (4)Socket (5)UDP TCP 传输 ...

  7. 第78节:Java中的网络编程(上)

    第78节:Java中的网络编程(上) 前言 网络编程涉及ip,端口,协议,tcp和udp的了解,和对socket通信的网络细节. 网络编程 OSI开放系统互连 网络编程指IO加网络 TCP/IP模型: ...

  8. Web API 令牌(秘钥是双方约定的,并不在网络连接上传输)

    http://blog.csdn.net/qq289523052/article/details/47750021 秘钥是双方约定的,并不在网络连接上传输 Web API数据传输加密 2015-08- ...

  9. Docker源码分析(七):Docker Container网络 (上)

    1.前言(什么是Docker Container) 如今,Docker技术大行其道,大家在尝试以及玩转Docker的同时,肯定离不开一个概念,那就是“容器”或者“Docker Container”.那 ...

随机推荐

  1. c#视频位置

    static void Main(string[] args)        { string scoure = @"C:\Documents and Settings\Administra ...

  2. C++ 数据结构 2:栈和队列

    1 栈 1.1 栈的基本概念 栈(stack)又名堆栈,它是一种 运算受限的线性表.限定 仅在表尾进行插入和删除操作 的线性表.表尾被称为栈顶,相对地,把另一端称为栈底. 1.1.1 特点 它的特殊之 ...

  3. 解决 cannot resolve 依赖包的问题

    在maven import的时候 报这样的错误 之前也经常碰到这样的错误,通过reimport.清缓存等方法都可以解决.但这次试了好多次都还是这样,查看maven后发现我pom文件里也没写错. 最后是 ...

  4. fork系统调用关于如何让子进程先运行{wait(),waitpid(),sleep()}

    在父进程里面调用wait()和waitpid()可以确保子进程先运行,因为当子进程运行完后会变成僵尸进程,此时会发送一个信号给父进程,父进程接受到信号才会运行. 有人或许会问如果在父进程调用wait之 ...

  5. subString引起的index out of range

    特别注意!!!低级坑 subString(begin,end)   subList()均存在这个问题. 当end>String.size(),则index out of range!!!

  6. binary hacks读数笔记(ld 链接讲解 二)

    这块将介绍一下ld链接命令的具体使用.ld的作用:ld是GNU binutils工具集中的一个,是众多Linkers(链接器)的一种.完成的功能自然也就是链接器的基本功能:把各种目标文件和库文件链接起 ...

  7. linux下制作软件包安装服务器

    linux下的软件包在有网络的情况下比较好安装,在ubuntu下,更新sourcelist,然后使用apt-get就可以很方便的安装包,在centos下面,更新yum列表,然后使用yum也可以进行方便 ...

  8. python分类

    python是一种动态解释性的强类型语言. python下分几个类别,分别是cpython,jypython,ironpython,pypy等等,这些属于不同的解释器,但编写规范只有一个就是pytho ...

  9. 算法:Common Subsequence(动态规划 Java 最长子序列)

    Description A subsequence of a given sequence is the given sequence with some elements (possible non ...

  10. Hbase启动报java异常

    在conf文件夹下的hbase-env.sh文件中的j添加ava_home的环境变量, ******************************************************** ...