(原)SphereFace及其pytorch代码
转载请注明出处:
http://www.cnblogs.com/darkknightzh/p/8524937.html
论文:
SphereFace: Deep Hypersphere Embedding for Face Recognition
https://arxiv.org/abs/1704.08063
http://wyliu.com/papers/LiuCVPR17v3.pdf
官方代码:
https://github.com/wy1iu/sphereface
pytorch代码:
https://github.com/clcarwin/sphereface_pytorch
说明:没用过mxnet,下面的代码注释只是纯粹从代码的角度来分析并进行注释,如有错误之处,敬请谅解,并欢迎指出。
传统的交叉熵公式如下:
${{L}_{i}}=-\log \frac{{{e}^{W_{yi}^{T}{{x}_{i}}+{{b}_{yi}}}}}{\sum\nolimits_{j}{{{e}^{W_{j}^{T}{{x}_{i}}+{{b}_{j}}}}}}=-\log \frac{{{e}^{\left\| {{W}_{yi}} \right\|\left\| {{x}_{i}} \right\|\cos ({{\theta }_{yi}},i)+{{b}_{yi}}}}}{\sum\nolimits_{j}{{{e}^{\left\| {{W}_{j}} \right\|\left\| {{x}_{i}} \right\|\cos ({{\theta }_{j}},i)+{{b}_{j}}}}}}$
将W归一化到1,且不考虑偏置项,即${{b}_{j}}=0$,则上式变成:
${{L}_{\text{modified}}}=\frac{1}{N}\sum\limits_{i}{-\log (\frac{{{e}^{\left\| {{x}_{i}} \right\|\cos ({{\theta }_{yi}},i)}}}{\sum\nolimits_{j}{{{e}^{\left\| {{x}_{i}} \right\|\cos ({{\theta }_{j}},i)}}}}})$
其中θ为w和x的夹角。
为了进一步限制夹角的范围,使用mθ,上式变成
${{L}_{\text{ang}}}=\frac{1}{N}\sum\limits_{i}{-\log (\frac{{{e}^{\left\| {{x}_{i}} \right\|\cos (m{{\theta }_{yi}},i)}}}{{{e}^{\left\| {{x}_{i}} \right\|\cos (m{{\theta }_{yi}},i)}}+\sum\nolimits_{j\ne yi}{{{e}^{\left\| {{x}_{i}} \right\|\cos ({{\theta }_{j}},i)}}}}})$
其中θ范围为$\left[ 0,\frac{\pi }{m} \right]$。
为了使得上式单调,引入$\psi ({{\theta }_{yi,i}})$:
${{L}_{\text{ang}}}=\frac{1}{N}\sum\limits_{i}{-\log (\frac{{{e}^{\left\| {{x}_{i}} \right\|\psi ({{\theta }_{yi,i}})}}}{{{e}^{\left\| {{x}_{i}} \right\|\psi ({{\theta }_{yi,i}})}}+\sum\nolimits_{j\ne yi}{{{e}^{\left\| {{x}_{i}} \right\|\cos ({{\theta }_{j}},i)}}}}})$
其中
$\psi ({{\theta }_{yi,i}})={{(-1)}^{k}}\cos (m{{\theta }_{yi,i}})-2k$,${{\theta }_{yi,i}}\in \left[ \frac{k\pi }{m},\frac{(k+1)\pi }{m} \right]$,$k\in \left[ 0,m-1 \right]$,$m\ge 1$
代码中引入了超参数λ,为
$\lambda =\max ({{\lambda }_{\min }},\frac{{{\lambda }_{\max }}}{1+0.1\times iterator})$
其中,${{\lambda }_{\min }}=5$,${{\lambda }_{\max }}=1500$为程序中预先设定的值。
实际的$\psi (\theta )$为
$\psi ({{\theta }_{yi}})=\frac{{{(-1)}^{k}}\cos (m{{\theta }_{yi}})-2k+\lambda \cos ({{\theta }_{yi}})}{1+\lambda }$
对应下面代码为:
output = cos_theta * 1.0
output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb)
output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb)
对于yi处的计算,
$output(yi)=\cos ({{\theta }_{yi}})-\frac{\cos ({{\theta }_{yi}})}{1+\lambda }+\frac{\psi ({{\theta }_{yi}})}{1+\lambda }=\frac{\psi ({{\theta }_{yi}})+\lambda \cos ({{\theta }_{yi}})}{1+\lambda }=\frac{{{(-1)}^{k}}\cos (m{{\theta }_{yi}})-2k+\lambda \cos ({{\theta }_{yi}})}{1+\lambda }$
和上面的公式对应。
具体的代码如下(完整的代码见参考网址):
class AngleLinear(nn.Module):
def __init__(self, in_features, out_features, m = 4, phiflag=True):
super(AngleLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(in_features,out_features))
self.weight.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
self.phiflag = phiflag
self.m = m
self.mlambda = [
lambda x: x**0, # cos(0*theta)=1
lambda x: x**1, # cos(1*theta)=cos(theta)
lambda x: 2*x**2-1, # cos(2*theta)=2*cos(theta)**2-1
lambda x: 4*x**3-3*x,
lambda x: 8*x**4-8*x**2+1,
lambda x: 16*x**5-20*x**3+5*x
] def forward(self, input): # input为输入的特征,(B, C),B为batchsize,C为图像的类别总数
x = input # size=(B,F),F为特征长度,如512
w = self.weight # size=(F,C) ww = w.renorm(2,1,1e-5).mul(1e5) #对w进行归一化,renorm使用L2范数对第1维度进行归一化,将大于1e-5的截断,乘以1e5,使得最终归一化到1.如果1e-5设置的过大,裁剪时某些很小的值最终可能小于1。注意,第0维度只对每一行进行归一化(每行平方和为1),第1维度指对每一列进行归一化。由于w的每一列为x的权重,因而此处需要对每一列进行归一化。如果要对x归一化,需要对每一行进行归一化,此时第二个参数应为0
xlen = x.pow(2).sum(1).pow(0.5) # 对输入x求平方,而后对不同列求和,再开方,得到每行的模,最终大小为第0维的,即B(由于对x不归一化,但是计算余弦时需要归一化,因而可以先计算模。但是对于w,不太懂为何不直接使用这种方式,而是使用renorm函数?)
wlen = ww.pow(2).sum(0).pow(0.5) # 对权重w求平方,而后对不同行求和,再开方,得到每列的模(理论上之前已经归一化,此处应该是1,但第一次运行到此处时,并不是1,不太懂),最终大小为第1维的,即C cos_theta = x.mm(ww) # 矩阵相乘(B,F)*(F,C)=(B,C),得到cos值,由于此处只是乘加,故未归一化
cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1) # 对每个cos值均除以B和C,得到归一化后的cos值
cos_theta = cos_theta.clamp(-1,1) #将cos值截断到[-1,1]之间,理论上不截断应该也没有问题,毕竟w和x都归一化后,cos值不可能超出该范围 if self.phiflag:
cos_m_theta = self.mlambda[self.m](cos_theta) # 通过cos_theta计算cos_m_theta,mlambda为cos_m_theta展开的结果
theta = Variable(cos_theta.data.acos()) # 通过反余弦,计算角度theta,(B,C)
k = (self.m*theta/3.14159265).floor() # 通过公式,计算k,(B,C)。此处为了保证theta大于k*pi/m,转换过来就是m*theta/pi,再向上取整
n_one = k*0.0 - 1 # 通过k的大小,得到同样大小的-1矩阵,(B,C)
phi_theta = (n_one**k) * cos_m_theta - 2*k # 通过论文中公式,得到phi_theta。(B,C)
else:
theta = cos_theta.acos() # 得到角度theta,(B, C),每一行为当前特征和w的每一列的夹角
phi_theta = myphi(theta,self.m) #
phi_theta = phi_theta.clamp(-1*self.m,1) cos_theta = cos_theta * xlen.view(-1,1) # 由于实际上不对x进行归一化,此处cos_theta需要乘以B。(B,C)
phi_theta = phi_theta * xlen.view(-1,1) # 由于实际上不对x进行归一化,此处phi_theta需要乘以B。(B,C)
output = (cos_theta,phi_theta)
return output # size=(B,C,2) class AngleLoss(nn.Module):
def __init__(self, gamma=0):
super(AngleLoss, self).__init__()
self.gamma = gamma
self.it = 0
self.LambdaMin = 5.0
self.LambdaMax = 1500.0
self.lamb = 1500.0 def forward(self, input, target):
self.it += 1
cos_theta,phi_theta = input # cos_theta,(B,C)。 phi_theta,(B,C)
target = target.view(-1,1) #size=(B,1) index = cos_theta.data * 0.0 #得到和cos_theta相同大小的全0矩阵。(B,C)
index.scatter_(1,target.data.view(-1,1),1) # 得到一个one-hot矩阵,第i行只有target[i]的值为1,其他均为0
index = index.byte() # index为float的,转换成byte类型
index = Variable(index) self.lamb = max(self.LambdaMin,self.LambdaMax/(1+0.1*self.it)) # 得到lamb
output = cos_theta * 1.0 #size=(B,C) # 如果直接使用output=cos_theta,可能不收敛(未测试,但其他程序中碰到过直接对输入使用[index]无法收敛,加上*1.0可以收敛的情况)
output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb) # 此行及下一行将target[i]的值通过公式得到最终输出
output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb) logpt = F.log_softmax(output) # 得到概率
logpt = logpt.gather(1,target) # 下面为交叉熵的计算(和focal loss的计算有点类似,当gamma为0时,为交叉熵)。
logpt = logpt.view(-1)
pt = Variable(logpt.data.exp()) loss = -1 * (1-pt)**self.gamma * logpt
loss = loss.mean() # target = target.view(-1) # 若要简化,理论上可直接使用这两行计算交叉熵(此处未测试,在其他程序中使用后可以正常训练)
# loss = F.cross_entropy(cos_theta, target) return loss
(原)SphereFace及其pytorch代码的更多相关文章
- 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)
首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...
- (转载)PyTorch代码规范最佳实践和样式指南
A PyTorch Tools, best practices & Styleguide 中文版:PyTorch代码规范最佳实践和样式指南 This is not an official st ...
- PyTorch代码调试利器: 自动print每行代码的Tensor信息
本文介绍一个用于 PyTorch 代码的实用工具 TorchSnooper.作者是TorchSnooper的作者,也是PyTorch开发者之一. GitHub 项目地址: https://github ...
- 如何将tensorflow1.x代码改写为pytorch代码(以图注意力网络(GAT)为例)
之前讲解了图注意力网络的官方tensorflow版的实现,由于自己更了解pytorch,所以打算将其改写为pytorch版本的. 对于图注意力网络还不了解的可以先去看看tensorflow版本的代码, ...
- pointnet.pytorch代码解析
pointnet.pytorch代码解析 代码运行 Training cd utils python train_classification.py --dataset <dataset pat ...
- 残差网络resnet理解与pytorch代码实现
写在前面 深度残差网络(Deep residual network, ResNet)自提出起,一次次刷新CNN模型在ImageNet中的成绩,解决了CNN模型难训练的问题.何凯明大神的工作令人佩服 ...
- 记录下pytorch代码从0.3版本迁移到0.4版本要做的一些更改。
1. UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to in ...
- 运行pytorch代码遇到的error解决办法
1.no CUDA-capable device is detected 首先考虑的是cuda的驱动问题,查看gpu显示是否正常,然后更新最新的cuda驱动: 第二个考虑的是cuda设备的默认参数是否 ...
- 目标检测之Faster-RCNN的pytorch代码详解(模型训练篇)
本文所用代码gayhub的地址:https://github.com/chenyuntc/simple-faster-rcnn-pytorch (非本人所写,博文只是解释代码) 好长时间没有发博客了 ...
随机推荐
- [leetcode]Spiral Matrix II @ Python
原题地址:https://oj.leetcode.com/problems/spiral-matrix-ii/ 题意: Given an integer n, generate a square ma ...
- 混合开发 Hybird Ionic Angular Cordova web 跨平台 MD
Markdown版本笔记 我的GitHub首页 我的博客 我的微信 我的邮箱 MyAndroidBlogs baiqiantao baiqiantao bqt20094 baiqiantao@sina ...
- 八一八android开发规范(一种建议)
开发规范重不重要了,不言而喻.这里就给大家说一故事把——据<圣经·旧约·创世记>第11章记载,是当时人类联合起来兴建,希望能通往天堂的高塔.为了阻止人类的计划,上帝让人类说不同的语言,使人 ...
- List 多次排序
List<Patientmain> list = patientmains.OrderBy(p => p.Firstname).ThenBy(p => p.Middlename ...
- SSAS知识回放之订单数据分析
1:目标 基于已经做好的DW,利用SSAS实现一个多维数据模型的创建,通过浏览可以简单的实现订单数据的分析 2:步骤 2.1:添加数据源 如下图所示,创建一个数据仓库层的数据源连接 2.2:添加数据源 ...
- CentOS7.1 Liberty云平台之环境准备(2)
一.各节点配置Openstack源库 yum install centos-release-openstack-liberty -y 升级YUM源库 yum upgrade 安装Openstackcl ...
- vs 2017 正规表达式替换整行多行数据
((<OutputFile>..*</OutputFile>)[\S\s])[\S\s] 从 <OutputFile> 开始 到 </OutputFile&g ...
- jquery easyui tree异步加载子节点
easyui中的树可以从标记中建立,也可以通过指定一个URL属性读取数据建立.如果想建立一棵异步树,需要为每个节点指定一个id属性值,这样在加载数据时会自动向后台传递id参数. <ul id=& ...
- Python爬虫学习系列教程
最近想学一下Python爬虫与检索相关的知识,在网上看到这个教程,觉得挺不错的,分享给大家. 来源:http://cuiqingcai.com/1052.html 一.Python入门 1. Pyth ...
- Javascript 闭包(Closures)
本文内容 闭包 闭包和引用 参考资料 闭包是 JavaScript 的重要特性,非常强大,可用于执行复杂的计算,可并不容易理解,尤其是对之前从事面向对象编程的人来说,对 JavaScript 认识和编 ...