腾讯开源人脸识别训练代码TFace 中关于all_gather层的实现如下。接下来解释为什么backward要进行reduce相加操作。

https://github.com/Tencent/TFace

class AllGatherFunc(Function):
""" AllGather op with gradient backword
"""
@staticmethod
def forward(ctx, tensor, *gather_list):
gather_list = list(gather_list)
dist.all_gather(gather_list, tensor)
return tuple(gather_list) @staticmethod
def backward(ctx, *grads):
grad_list = list(grads)
rank = dist.get_rank()
grad_out = grad_list[rank]
dist_ops = [
dist.reduce(grad_out, rank, ReduceOp.SUM, async_op=True) if i == rank else
dist.reduce(grad_list[i], i, ReduceOp.SUM, async_op=True) for i in range(dist.get_world_size())
]
for _op in dist_ops:
_op.wait()
grad_out *= len(grad_list) # cooperate with distributed loss function
return (grad_out, *[None for _ in range(len(grad_list))])
AllGather = AllGatherFunc.apply

下面用示意图来描述大规模人脸分类的过程,如下图。

结合下面示意图和公式表达来理解。

B: batch size, d: feature dimension, K: gpu number, C: class number, \(c_j\): class number of j-th gpu

(1)\(F_j \in R^{B*d}\): 第j块GPU上特征

(2)\(F_{total} = torch.cat((F_0, F_1, ^, F_{K-1} )) \in R^{KB*d}\): 表示所有的K个GPU上特征合并在一起

(3)\(W_j \in R^{d*c_j}\):第j块GPU上的分类权重

(4)\(logit_j=F_{total}W_j \in R^{KB*c_j}\): 这里简化分类层为常规线性变换。(下面的公式中\(y_j\)就表示\(logit_j\))

\(\frac {\partial L_j}{\partial F_{total}} = \frac{\partial L_j}{\partial y_j}* \frac{\partial y_j}{\partial F_{total}}=\frac{\partial L_j}{\partial y_j}*W_j^T\),(\(R^{KB*c_j}*R^{c_j*d}=R^{KB*d}\),数据维度是可以对应上的)。

  可以看出每块GPU上产生的对全体特征向量的梯度维度都是一样(这个是肯定的),每块GPU上产生梯度是通过上述链式法则得到的,得到梯度的公式中,分两个部分相乘,一个是对logit值的导数,一个是当前卡上局部分类权重W的导数。对于每块卡而言这两部分都不一样。也就是每块gpu都对全体特征向量\(F_{total}\)都产生梯度。总的loss是各个GPU上loss先求和再归约,因此在求对logit梯度时,也除以了总的样本数量(KB),然后对全体特征向量\(F_{total}\)在allgather层要进行相加。\(\frac{\partial L}{\partial F_{total}}=\frac{1}{KB}\sum _{j=0}^{j=K-1}\frac {\partial L_j}{\partial F_{total}} =\frac{1}{KB}\sum _{j=0}^{j=K-1}\frac{\partial L_j}{\partial y_j}*W_j^T=\sum _{j=0}^{j=K-1}\frac{1}{KB}\frac{\partial L_j}{\partial y_j}*W_j^T\)。

 可是不明白上述代码为什么要乘以GPU的数量,对应代码为:grad_out *= len(grad_list)

大规模人脸分类—allgather操作(2)的更多相关文章

  1. 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...

  2. [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...

  3. 用keras的cnn做人脸分类

    keras介绍 Keras是一个简约,高度模块化的神经网络库.采用Python / Theano开发. 使用Keras如果你需要一个深度学习库: 可以很容易和快速实现原型(通过总模块化,极简主义,和可 ...

  4. wordpress搜索结果排除某个分类如何操作

    我们知道wordpress的搜索结果页search.php和分类页category.php是一样的,但是客户的网站是功能比较多的系统,有新闻又有产品,如果搜索结果只想展示产品要如何操作呢?随ytkah ...

  5. SQL分类-DDL_操作数据库_创建&查询

    SQL分类 1.DDL(Data Definition Language)数据定义语言 用来定义数据库对象:数据库,表,列等.关键字:create , drop, alter 等 2.DML(Data ...

  6. python集合的分类与操作

    如图: 集合的炒作分类: 确定大小 测试项的成员关系 遍历集合 获取一个字符串表示 测试相等性 连接两个集合 转换为另一种类型的集合 插入一项 删除一项 替换一项 访问或获取一项

  7. Python函数分类及操作

    为什么使用函数? 答:函数的返回值可以确切知道整个函数执行的结果   函数的定义:1.数学意义的函数:两个变量:自变量x和因变量y,二者的关系                      2.Pytho ...

  8. .NET做人脸识别并分类

    .NET做人脸识别并分类 在游乐场.玻璃天桥.滑雪场等娱乐场所,经常能看到有摄影师在拍照片,令这些经营者发愁的一件事就是照片太多了,客户在成千上万张照片中找到自己可不是件容易的事.在一次游玩等活动或家 ...

  9. face recognition[翻译][深度学习理解人脸]

    本文译自<Deep learning for understanding faces: Machines may be just as good, or better, than humans& ...

  10. face recognition[翻译][深度人脸识别:综述]

    这里翻译下<Deep face recognition: a survey v4>. 1 引言 由于它的非侵入性和自然特征,人脸识别已经成为身份识别中重要的生物认证技术,也已经应用到许多领 ...

随机推荐

  1. 新手入门Neo4j,手把手完整教学

    1. 图数据库Neo4j简介 1.1 什么是图数据库 图数据库:是基于图论实现的一种NoSQL数据库,其数据结构是和查询方式是以图论为基础的,图数据库主要用于存储更多的连接数据. 图论:用多个节点代表 ...

  2. Spark应用程序第三方jar文件依赖解决方案

    第一种方式 操作:将第三方jar文件打包到最终形成的spark应用程序jar文件中 应用场景:第三方jar文件比较小,应用的地方比较少 第二种方式 操作:使用spark-submit提交命令的参数: ...

  3. windows文件夹被占用的解除办法

    1.第一步,按下快捷键组合 ctrl alt del,打开任务管理器窗口,点击上方菜单栏中的性能选项. 2. 第二步,在性能页面下找到打开资源监视器按钮并点击. 3. 第三步,进入资源监视器页面,点击 ...

  4. cuda、cudnn、tnesorrt的查看安装

    1.首先本地查看cuda已安装的版本 11.7输入命令:[nvcc -V]输出:nvcc: NVIDIA (R) Cuda compiler driverCopyright (c) 2005-2022 ...

  5. 提高NTC测温精度(转发)

    (一)一般精度要求:采样数据的获取,直接采用恒流源(或恒压源)上拉方式.见图(2)所示.  原理:将恒流源(或恒压源)直接作用于NTC热敏电阻Rt上,当被测对象的温度发生变化,NTC热敏电阻的阻值Rt ...

  6. Java语言中的复合运算符会自动进行类型转换

    计算1/1-1/2+1/3+--+1/99-1/100 public class LoopControlExercise08{ public static void main(String[] arg ...

  7. Quartz.Net的简单使用

    1.安装Quartz.Net Install-Package Quartz -Version 2.5.0 2.需要执行定时任务的代码,新建一个类,继承IJob接口,并实现该接口 public clas ...

  8. 第六章:用Python实现自动发送邮件和发送钉钉消息

    目录 发送邮件源码 发送钉钉消息源码 源码地址 本文可以学习到以下内容: 使用requests库发送钉钉消息 使用email和smtplib库发送邮件 使用163邮箱服务,自动发送邮件及附件 发送邮件 ...

  9. 【python】pip的问题

    在更新pip的时候因为安装pip需要管理员权限,但是呢,又把pip给uninstall了 所以结果就是没有pip这个模块了 https://stackoverflow.com/questions/18 ...

  10. 连接HBase

    单线连接HBasepublic class HBaseConnection { public static void main(String[] args) throws IOException { ...