计图MPI分布式多卡

计图分布式基于MPI(Message Passing Interface),主要阐述使用计图MPI,进行多卡和分布式训练。目前计图分布式处于测试阶段。

计图MPI安装

计图依赖OpenMPI,用户可以使用如下命令安装OpenMPI:

sudo apt install openmpi-bin openmpi-common libopenmpi-dev

计图会自动检测环境变量中是否包含mpicc,如果计图成功的检测到了mpicc,输出如下信息:

[i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc

如果计图没有在环境变量中找到mpi,用户也可以手动指定mpicc的路径告诉计图,添加环境变量即可:export mpicc_path=/you/mpicc/path

OpenMPI安装完成以后,用户无需修改代码,需要做的仅仅是修改启动命令行,计图就会用数据并行的方式,自动完成并行操作。

# 单卡训练代码

python3.7 -m jittor.test.test_resnet

# 分布式多卡训练代码

mpirun -np 4 python3.7 -m jittor.test.test_resnet

# 指定特定显卡的多卡训练代码

CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 -m jittor.test.test_resnet

便捷性的背后,计图的分布式算子的支撑,计图支持的mpi算子后端会使用nccl进行进一步的加速。计图所有分布式算法的开发,均在Python前端完成,让分布式算法的灵活度增强,开发分布式算法的难度也大大降低。

基于这些mpi算子接口,研发团队已经集成了如下三种分布式相关的算法:

  • 分布式数据并行加载
  • 分布式优化器
  • 分布式同步批归一化层

用户在使用MPI进行分布式训练时,计图内部的Dataset类会自动并行分发数据,需要注意的是Dataset类中设置的Batch size是所有节点的batch size之和,也就是总batch size,不是单个节点接收到的batch size。

MPI接口

目前MPI开放接口如下:

  • jt.mpi: 计图的MPI模块,当计图不在MPI环境下时,jt.mpi == None, 用户可以用这个判断是否在mpi环境下。
  • jt.Module.mpi_param_broadcast(root=0): 将模块的参数从root节点广播给其他节点。
  • jt.mpi.mpi_reduce(x, op='add', root=0): 将所有节点的变量x使用算子op,reduce到root节点。如果op是’add’或者’sum’,该接口会把所有变量求和,如果op是’mean’,该接口会取均值。

  • jt.mpi.mpi_broadcast(x, root=0): 将变量x从root节点广播到所有节点。

  • jt.mpi.mpi_all_reduce(x, op='add'): 将所有节点的变量x使用一起reduce,并且吧reduce的结果再次广播到所有节点。如果op是’add’或者’sum’,该接口会把所有变量求和,如果op是’mean’,该接口会取均值。

实例:MPI实现分布式同步批归一化层

下面的代码是使用计图实现分布式同步批,归一化层的实例代码,在原来批归一化层的基础上,只需增加三行代码,就可以实现分布式的batch norm,添加的代码如下:

# 将均值和方差,通过all reduce同步到所有节点

if self.sync and jt.mpi:

xmean = xmean.mpi_all_reduce("mean")

x2mean = x2mean.mpi_all_reduce("mean")

注:计图内部已经实现了同步的批归一化层,用户不需要自己实现

分布式同步批归一化层的完整代码:

class BatchNorm(Module):

def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):

assert affine == None

self.sync = sync

self.num_features = num_features

self.is_train = is_train

self.eps = eps

self.momentum = momentum

self.weight = init.constant((num_features,), "float32", 1.0)

self.bias = init.constant((num_features,), "float32", 0.0)

self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()

self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

def execute(self, x):

if self.is_train:

xmean = jt.mean(x, dims=[0,2,3], keepdims=1)

x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)

# 将均值和方差,通过all reduce同步到所有节点

if self.sync and jt.mpi:

xmean = xmean.mpi_all_reduce("mean")

x2mean = x2mean.mpi_all_reduce("mean")

xvar = x2mean-xmean*xmean

norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)

self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum

self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum

else:

running_mean = self.running_mean.broadcast(x, [0,2,3])

running_var = self.running_var.broadcast(x, [0,2,3])

norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)

w = self.weight.broadcast(x, [0,2,3])

b = self.bias.broadcast(x, [0,2,3])

return norm_x * w + b

计图MPI分布式多卡的更多相关文章

  1. 计图(Jittor) 1.1版本:新增骨干网络、JIT功能升级、支持多卡训练

    计图(Jittor) 1.1版本:新增骨干网络.JIT功能升级.支持多卡训练 深度学习框架-计图(Jittor),Jittor的新版本V1.1上线了.主要变化包括: 增加了大量骨干网络的支持,增强了辅 ...

  2. openlayers-统计图显示(中国区域高亮)

    openlayers版本: v3.19.1-dist 统计图效果:         案例下载地址:https://gitee.com/kawhileonardfans/openlayers-examp ...

  3. 用动图讲解分布式 Raft

    一.Raft 概述 Raft 算法是分布式系统开发首选的共识算法.比如现在流行 Etcd.Consul. 如果掌握了这个算法,就可以较容易地处理绝大部分场景的容错和一致性需求.比如分布式配置系统.分布 ...

  4. 8.3 MPI

    MPI 模型 如图MPI的各个运算节点是分布式的.每一个节点可以视为是一个“Thread”,但这里的不同之处在于这些节点没有所谓的共享内存,或者说Global Memory.所以,在后面也会看到,一般 ...

  5. Horovod 分布式深度学习框架相关

    最近需要 Horovod 相关的知识,在这里记录一下,进行备忘: 分布式训练,分为数据并行和模型并行两种: 模型并行:分布式系统中的不同GPU负责网络模型的不同部分.神经网络模型的不同网络层被分配到不 ...

  6. Samsung S4卡屏卡在开机画面的不拆机恢复照片一例

    大家好!欢迎再次来到我Dr.wonder的世界, 今天我给你们带来Samsung S4 I9508 卡屏开在开机画面的恢复!非常de经典. 首先看图 他开机一直卡在这里, 然后 ,我们使用专业仪器,在 ...

  7. 云时代的分布式数据库:阿里分布式数据库服务DRDS

    发表于2015-07-15 21:47| 10943次阅读| 来源<程序员>杂志| 27 条评论| 作者王晶昱 <程序员>杂志数据库DRDS分布式沈询 摘要:伴随着系统性能.成 ...

  8. Spark入门实战系列--9.Spark图计算GraphX介绍及实例

    [注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 .GraphX介绍 1.1 GraphX应用背景 Spark GraphX是一个分布式图处理 ...

  9. 学习笔记:The Log(我所读过的最好的一篇分布式技术文章)

    前言 这是一篇学习笔记. 学习的材料来自Jay Kreps的一篇讲Log的博文. 原文很长,但是我坚持看完了,收获颇多,也深深为Jay哥的技术能力.架构能力和对于分布式系统的理解之深刻所折服.同时也因 ...

随机推荐

  1. 【ElasticSearch】文档路由的原理

    ElasticSearch集群环境下新增文档如何确认该文档被分配到哪个分片中? 路由算法: ⾸先这肯定不会是随机的,否则将来要获取⽂档的时候我们就不知道从何处寻找了.实际上,这个过程是根据下⾯这个公式 ...

  2. 【Scrapy(二)】Scrapy 中的 Pipline,Item,Shell组件

    Pipline: 1.爬虫项目与爬虫的区别与关联: 一个爬虫项目可以包含多个爬虫,如下图中爬虫项目firstspider 包含多个爬虫itcst 和爬虫itcast1 2.多个爬虫是公用一套Pipli ...

  3. 使用DirectX截屏

    网上有很多关于DirectX截屏的文章,但大都是屏幕截图,很少有窗口截图,本文则两者都涉及到,先讲如何截取整个屏幕,再讲如何截取某个窗口,其实二者的区别不大,只是某个参数的设置不同而已,最后我们还将扩 ...

  4. POJ1719行列匹配

    题意:      给一个n*m的格子,每一列都有两个白色的,其余的全是黑色的,然后要选择m个格子,要求是每一列必须也只能选一个,而每一行至少选择一个,输出一种可行的方案没,输出的格式是输出m个数,表示 ...

  5. Windows核心编程 第七章 线程的调度、优先级和亲缘性(下)

    7.6 运用结构环境 现在应该懂得环境结构在线程调度中所起的重要作用了.环境结构使得系统能够记住线程的状态,这样,当下次线程拥有可以运行的C P U时,它就能够找到它上次中断运行的地方. 知道这样低层 ...

  6. 【python】Leetcode每日一题-扰乱字符串

    [python]Leetcode每日一题-扰乱字符串 [题目描述] 使用下面描述的算法可以扰乱字符串 s 得到字符串 t : 如果字符串的长度为 1 ,算法停止 如果字符串的长度 > 1 ,执行 ...

  7. Day009 二维数组

    多维数组 多维数组是数组的嵌套(数组的元素是数组,数组的数组元素的元素是数组...),比如二维数组就是一个特殊的一维数组,其每一个元素都是一个一维数组. 二维数组 int a[][]=new int ...

  8. 百度地图api逆地址解析 PHP

    一.说明:逆地址查询就是根据经纬度信息获取地址位置信息 二.参数:$lat:纬度值 ,$lng:经度值 ,$ak = 自己的AK:(百度地图开放平台对应ak链接:http://lbsyun.baidu ...

  9. Asp.NetCore Web开发之模型验证

    在开发中,验证表单数据是很重要的一环,如果对用户输入的数据不加限制,那么当错误的数据提交到后台后,轻则破坏数据的有效性,重则会导致服务器瘫痪,这是很致命的. 所以进行数据有效性验证是必要的,我们一般通 ...

  10. .Net Core平台下,添加包的引用

    一个程序的开发过程中离不开对程序集(Assembly,将程序集打包好,就成为一个.dll的包文件,它也叫动态链接库(Dynamic Link Library​))的依赖,在以前ASP.Net时代,微软 ...