计图MPI分布式多卡
计图MPI分布式多卡
计图分布式基于MPI(Message Passing Interface),主要阐述使用计图MPI,进行多卡和分布式训练。目前计图分布式处于测试阶段。
计图MPI安装
计图依赖OpenMPI,用户可以使用如下命令安装OpenMPI:
sudo apt install openmpi-bin openmpi-common libopenmpi-dev
计图会自动检测环境变量中是否包含mpicc,如果计图成功的检测到了mpicc,输出如下信息:
[i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc
如果计图没有在环境变量中找到mpi,用户也可以手动指定mpicc的路径告诉计图,添加环境变量即可:export mpicc_path=/you/mpicc/path
OpenMPI安装完成以后,用户无需修改代码,需要做的仅仅是修改启动命令行,计图就会用数据并行的方式,自动完成并行操作。
# 单卡训练代码
python3.7 -m jittor.test.test_resnet
# 分布式多卡训练代码
mpirun -np 4 python3.7 -m jittor.test.test_resnet
# 指定特定显卡的多卡训练代码
CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 -m jittor.test.test_resnet
便捷性的背后,计图的分布式算子的支撑,计图支持的mpi算子后端会使用nccl进行进一步的加速。计图所有分布式算法的开发,均在Python前端完成,让分布式算法的灵活度增强,开发分布式算法的难度也大大降低。
基于这些mpi算子接口,研发团队已经集成了如下三种分布式相关的算法:
- 分布式数据并行加载
- 分布式优化器
- 分布式同步批归一化层
用户在使用MPI进行分布式训练时,计图内部的Dataset类会自动并行分发数据,需要注意的是Dataset类中设置的Batch size是所有节点的batch size之和,也就是总batch size,不是单个节点接收到的batch size。
MPI接口
目前MPI开放接口如下:
- jt.mpi: 计图的MPI模块,当计图不在MPI环境下时,jt.mpi == None, 用户可以用这个判断是否在mpi环境下。
- jt.Module.mpi_param_broadcast(root=0): 将模块的参数从root节点广播给其他节点。
- jt.mpi.mpi_reduce(x, op='add', root=0): 将所有节点的变量x使用算子op,reduce到root节点。如果op是’add’或者’sum’,该接口会把所有变量求和,如果op是’mean’,该接口会取均值。
- jt.mpi.mpi_broadcast(x, root=0): 将变量x从root节点广播到所有节点。
- jt.mpi.mpi_all_reduce(x, op='add'): 将所有节点的变量x使用一起reduce,并且吧reduce的结果再次广播到所有节点。如果op是’add’或者’sum’,该接口会把所有变量求和,如果op是’mean’,该接口会取均值。
实例:MPI实现分布式同步批归一化层
下面的代码是使用计图实现分布式同步批,归一化层的实例代码,在原来批归一化层的基础上,只需增加三行代码,就可以实现分布式的batch norm,添加的代码如下:
# 将均值和方差,通过all reduce同步到所有节点
if self.sync and jt.mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
注:计图内部已经实现了同步的批归一化层,用户不需要自己实现
分布式同步批归一化层的完整代码:
class BatchNorm(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
self.sync = sync
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
def execute(self, x):
if self.is_train:
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
# 将均值和方差,通过all reduce同步到所有节点
if self.sync and jt.mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
else:
running_mean = self.running_mean.broadcast(x, [0,2,3])
running_var = self.running_var.broadcast(x, [0,2,3])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
计图MPI分布式多卡的更多相关文章
- 计图(Jittor) 1.1版本:新增骨干网络、JIT功能升级、支持多卡训练
计图(Jittor) 1.1版本:新增骨干网络.JIT功能升级.支持多卡训练 深度学习框架-计图(Jittor),Jittor的新版本V1.1上线了.主要变化包括: 增加了大量骨干网络的支持,增强了辅 ...
- openlayers-统计图显示(中国区域高亮)
openlayers版本: v3.19.1-dist 统计图效果: 案例下载地址:https://gitee.com/kawhileonardfans/openlayers-examp ...
- 用动图讲解分布式 Raft
一.Raft 概述 Raft 算法是分布式系统开发首选的共识算法.比如现在流行 Etcd.Consul. 如果掌握了这个算法,就可以较容易地处理绝大部分场景的容错和一致性需求.比如分布式配置系统.分布 ...
- 8.3 MPI
MPI 模型 如图MPI的各个运算节点是分布式的.每一个节点可以视为是一个“Thread”,但这里的不同之处在于这些节点没有所谓的共享内存,或者说Global Memory.所以,在后面也会看到,一般 ...
- Horovod 分布式深度学习框架相关
最近需要 Horovod 相关的知识,在这里记录一下,进行备忘: 分布式训练,分为数据并行和模型并行两种: 模型并行:分布式系统中的不同GPU负责网络模型的不同部分.神经网络模型的不同网络层被分配到不 ...
- Samsung S4卡屏卡在开机画面的不拆机恢复照片一例
大家好!欢迎再次来到我Dr.wonder的世界, 今天我给你们带来Samsung S4 I9508 卡屏开在开机画面的恢复!非常de经典. 首先看图 他开机一直卡在这里, 然后 ,我们使用专业仪器,在 ...
- 云时代的分布式数据库:阿里分布式数据库服务DRDS
发表于2015-07-15 21:47| 10943次阅读| 来源<程序员>杂志| 27 条评论| 作者王晶昱 <程序员>杂志数据库DRDS分布式沈询 摘要:伴随着系统性能.成 ...
- Spark入门实战系列--9.Spark图计算GraphX介绍及实例
[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 .GraphX介绍 1.1 GraphX应用背景 Spark GraphX是一个分布式图处理 ...
- 学习笔记:The Log(我所读过的最好的一篇分布式技术文章)
前言 这是一篇学习笔记. 学习的材料来自Jay Kreps的一篇讲Log的博文. 原文很长,但是我坚持看完了,收获颇多,也深深为Jay哥的技术能力.架构能力和对于分布式系统的理解之深刻所折服.同时也因 ...
随机推荐
- C/C++ 介绍的PE文件遍历工具
在前面的笔记中,我总结了Pe结构的一些结构含义,并手动编写了几段PE结构遍历代码,这里我直接把之前的C语言代码进行了封装,形成了一个命令行版的PE文件查看工具,该工具只有20kb,但却可以遍历出大部分 ...
- Backdoor.Zegost木马病毒分析(一)
http://blog.csdn.net/qq1084283172/article/details/50413426 一.样本信息 样本名称:rt55.exe 样本大小: 159288 字节 文件类型 ...
- hdu4982 暴搜+剪枝(k个数和是n,k-1个数的和是平方数)
题意: 给你两个数n,k问你是否怎在这样一个序列: (1)这个序列有k个正整数,且不重复. (2)这k个数的和是n. (3)其中有k-1个数的和是一个平方数. ...
- UVA11134传说中的车(放棋子)
题意: 给你一个n*n的棋盘,让你在棋盘上放n个棋子,要求是所有棋子不能相互攻击(同行或者同列就会攻击),并且每个棋子都有一个限制,那就是必须在给定的矩形r[i]里,输出每个棋子的位置,s ...
- 神经网络与机器学习 笔记—单神经元解决XOR问题
单神经元解决XOR问题 有两个输入的单个神经元的使用得到的决策边界是输入空间的一条直线.在这条直线的一边的所有的点,神经元输出1:而在这条直线的另一边的点,神经元输出0.在输入空间中,这条直线的位置和 ...
- Day006 方法的定义和调用
方法的定义 Java的方法类似于其他语言的函数,是一段用来完成特定功能的代码片段,一般情况下,定义一个方法包含以下语法: 方法包含一个方法头和一个方法体.下面是一个方法的所有部分: 修饰符:修饰符,这 ...
- 第三部分 IDEA创建并运行项目
可以创建一个maven,几行代码就解决了导入依赖,但是我的电脑不知道哪里出现了问题,IDEA重装,jdk重装,maven重装,都无法解决问题,找了3天,还是没有解决问题.最后只能采用手动导入包方法.看 ...
- IDEA 新建 Java 项目 (图文讲解, 良心教程)
IDEA 新建 Java 项目 (图文讲解, 良心教程) 欢迎关注博主公众号「Java大师」, 专注于分享Java领域干货文章, 关注回复「资源」, 免费领取全网最热的Java架构师学习PDF, 转载 ...
- MySQL字段类型最全解析
前言: 要了解一个数据库,我们必须了解其支持的数据类型.MySQL 支持大量的字段类型,其中常用的也有很多.前面文章我们也讲过 int 及 varchar 类型的用法,但一直没有全面讲过字段类型,本篇 ...
- “深度评测官”——记2020BUAA软工软件案例分析作业
项目 内容 这个作业属于哪个课程 2020春季计算机学院软件工程(罗杰 任建) 这个作业的要求在哪里 个人博客作业-软件案例分析 我在这个课程的目标是 完成一次完整的软件开发经历并以博客的方式记录开发 ...