现在有四张卡,但是部署在windows10系统上,想尝试下在windows上使用单机多卡进行分布式训练,网上找了一圈硬是没找到相关的文章。以下是踩坑过程。

首先,pytorch的版本必须是大于1.7,这里使用的环境是:

pytorch==1.12+cu11.6
四张4090显卡
python==3.7.6

使用nn.DataParallel进行分布式训练

这一种方式较为简单:

首先我们要定义好使用的GPU的编号,GPU按顺序依次为0,1,2,3。gpu_ids可以通过命令行的形式传入:

gpu_ids = args.gpu_ids.split(',')
gpu_ids = [int(i) for i in gpu_ids]
torch.cuda.set_device('cuda:{}'.format(gpu_ids[0]))

创建模型后用nn.DataParallel进行处理,

 model.cuda()
r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])

对,没错,只需要这么两步就行了。需要注意的是保存模型后进行加载时,需要先用nn.DataParallel进行处理,再加载权重,不然参数名没对齐会报错。

checkpoint = torch.load(checkpoint_path)
model.cuda()
r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
r_model.load_state_dict(checkpoint['state_dict'])

如果不使用分布式加载模型,你需要对权重进行映射:

new_start_dict = {}
for k, v in checkpoint['state_dict'].items():
new_start_dict["module." + k] = v
model.load_state_dict(new_start_dict)

使用Distributed进行分布式训练

首先了解一下概念:

node:主机数,单机多卡就一个主机,也就是1。

rank:当前进程的序号,用于进程之间的通讯,rank=0的主机为master节点。

local_rank:当前进程对应的GPU编号。

world_size:总的进程数。

在windows中,我们需要在py文件里面使用:

import os
os.environ["CUDA_VISIBLE_DEVICES]='0,1,3'

来指定使用的显卡。

假设现在我们使用上面的三张显卡,运行时显卡会重新按照0-N进行编号,有:

[38664] rank = 1, world_size = 3, n = 1, device_ids = [1]
[76032] rank = 0, world_size = 3, n = 1, device_ids = [0]
[23208] rank = 2, world_size = 3, n = 1, device_ids = [2]

也就是进程0使用第1张显卡,进行1使用第2张显卡,进程2使用第三张显卡。

有了上述的基本知识,再看看具体的实现。

使用torch.distributed.launch启动

使用torch.distributed.launch启动时,我们必须要在args里面添加一个local_rank参数,也就是:

parser.add_argument("--local_rank", type=int, default=0)

1、初始化:

import torch.distributed as dist

env_dict = {
key: os.environ[key]
for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
}
current_work_dir = os.getcwd()
init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=int(env_dict["RANK"]),
world_size=int(env_dict["WORLD_SIZE"]))

这里需要重点注意,这种启动方式在环境变量中会存在RANK和WORLD_SIZE,我们可以拿来用。backend必须指定为gloo,init_method必须是file:///,而且每次运行完一次,下一次再运行前都必须删除生成的ddp_example,不然会一直卡住。

2、构建模型并封装

local_rank会自己绑定值,不再是我们--local_rank指定的。

 model.cuda(args.local_rank)
r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

3、构建数据集加载器并封装

  train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file))
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size,
collate_fn=collate.collate_fn, num_workers=4, sampler=train_sampler)

4、计算损失函数

我们把每一个GPU上的loss进行汇聚后计算。

def loss_reduce(self, loss):
rt = loss.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= self.args.local_world_size
return rt loss = self.criterion(outputs, labels)
torch.distributed.barrier()
loss = self.loss_reduce(loss)

注意打印相关信息和保存模型的时候我们通常只需要在local_rank=0时打印。同时,在需要将张量转换到GPU上时,我们需要指定使用的GPU,通过local_rank指定就行,即data.cuda(args.local_rank),保证数据在对应的GPU上进行处理。

5、启动

在windows下需要把换行符去掉,且只变为一行。

python -m torch.distributed.launch \
--nnode=1 \
--node_rank=0 \
--nproc_per_node=3 \
main_distributed.py \
--local_world_size=3 \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

nproc_per_node、local_world_size和GPU的数目保持一致。

使用torch.multiprocessing启动

使用torch.multiprocessing启动和使用torch.distributed.launch启动大体上是差不多的,有一些地方需要注意。

mp.spawn(main_worker, nprocs=args.nprocs, args=(args,))

main_worker是我们的主运行函数,dist.init_process_group要放在这里面,而且第一个参数必须为local_rank。即main_worker(local_rank, args)

nprocs是进程数,也就是使用的GPU数目。

args按顺序传入main_worker真正使用的参数。

其余的就差不多。

启动指令:

python main_mp_distributed.py \
--local_world_size=4 \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

最后需要说明的,假设我们设置的batch_size=64,那么实际上的batch_size = int(batch_size / GPU数目)。

附上完整的基于bert的中文文本分类单机多卡训练代码:https://github.com/taishan1994/pytorch_bert_chinese_text_classification

参考

https://github.com/tczhangzhi/pytorch-distributed

https://murphypei.github.io/blog/2020/09/pytorch-distributed

https://pytorch.org/docs/master/distributed.html?highlight=all_gather#torch.distributed.all_gather

https://github.com/lesliejackson/PyTorch-Distributed-Training

https://github.com/pytorch/examples/blob/ddp-tutorial-code/distributed/ddp/example.py

996黄金一代:[原创][深度][PyTorch] DDP系列第一篇:入门教程

「新生手册」:PyTorch分布式训练 - 知乎 (zhihu.com)

windows下使用pytorch进行单机多卡分布式训练的更多相关文章

  1. windows下安装pytorch

    安装: https://blog.csdn.net/xiangxianghehe/article/details/80103095 Windows下通过pip安装PyTorch 0.4.0 impor ...

  2. 云原生的弹性 AI 训练系列之二:PyTorch 1.9.0 弹性分布式训练的设计与实现

    背景 机器学习工作负载与传统的工作负载相比,一个比较显著的特点是对 GPU 的需求旺盛.在之前的文章中介绍过(https://mp.weixin.qq.com/s/Nasm-cXLtJObjLwLQH ...

  3. Pytorch使用分布式训练,单机多卡

    pytorch的并行分为模型并行.数据并行 左侧模型并行:是网络太大,一张卡存不了,那么拆分,然后进行模型并行训练. 右侧数据并行:多个显卡同时采用数据训练网络的副本. 一.模型并行 二.数据并行 数 ...

  4. Windows 下单机最大TCP连接数

    在做Socket 编程时,我们经常会要问,单机最多可以建立多少个 TCP 连接,本文将介绍如何调整系统参数来调整单机的最大TCP连接数. Windows 下单机的TCP连接数有多个参数共同决定,下面一 ...

  5. windows下配置cuda9.0和pytorch

    今天看了看pytorch官网竟然支持windows了,赶紧搞一个. 下载cuda 9.0  https://developer.nvidia.com/cuda-downloads 下载anaconda ...

  6. PyTorch在64位Windows下的Conda包(转载)

    PyTorch在64位Windows下的Conda包 昨天发了一篇PyTorch在64位Windows下的编译过程的文章,有朋友觉得能不能发个包,这样就不用折腾了.于是,这个包就诞生了.感谢@晴天14 ...

  7. redis在Windows下以后台服务一键搭建哨兵(主从复制)模式(单机)

    redis在Windows下以后台服务一键搭建哨兵(主从复制)模式(单机) 一.概述 此教程介绍如何在windows系统中单机布置redis哨兵模式(主从复制),同时要以后台服务的模式运行.布置以脚本 ...

  8. redis在Windows下以后台服务一键搭建集群(单机--伪集群)

    redis在Windows下以后台服务一键搭建集群(单机--伪集群) 一.概述 此教程介绍如何在windows系统中同一台机器上布置redis伪集群,同时要以后台服务的模式运行.布置以脚本的形式,一键 ...

  9. windows下PyTorch安装之路记录

    最近两天被windows下pytorch的安装给搞得很烦了,不过在今天终于安装成功了,如下图所示 下面详细说下此次安装的详细记录吧.我的电脑环境是Windows10+cuda9.0+cudnn7.1. ...

  10. Windows下nacos单机部分发现的坑

    一.下载nacos的地址: https://github.com/alibaba/nacos/releases 下载 nacos-server-1.3.2.tar.gz    就好 二.在Window ...

随机推荐

  1. Mybatis配置报错:Failed to configure a DataSource: 'url' attribute is not specified and no embe...

    问题分析及解决方案 问题原因: Mybatis没有找到合适的加载类,其实是大部分spring - datasource - url没有加载成功,分析原因如下所示. DataSourceAutoConf ...

  2. 怎样修改linux内核

    1.先查看linux内核 uname -a 2.打开内核配置文件 sudo vi /etc/default/grub 3.跟新grub文件 sudo update-grub 4.最后重启电脑 sudo ...

  3. conda相关的设置备忘

    因为默认channel已经没有3.4.4(最后一个支持xp的python3)了,为了添加这个的版本,尝试先用conda-forge channel: conda create -n myenv pyt ...

  4. XML_DOM4J_20200415

    package com.wy.xml; import java.io.File;import java.util.Iterator; import org.dom4j.Attribute; impor ...

  5. OSIDP-I/O管理和磁盘调度-11

    I/O设备 I/O设备分类:人可读.机器可读和远程通信三类. I/O设备之间的差别: 1.数据传送速率 2.应用领域 3.控制的复杂性 4.传送单位 5.数据表示形式 6.错误条件 I/O功能的组织 ...

  6. swftools工具将pdf文件转换为swf文件 文字丢失

    开发客户网站时遇到了一个需求,客户要求后台上传pdf文件,前台能以翻书的形式直接访问. 首先想到的是使用js解决,用户访问前端页面时,php将文件路径发送给js,让js呈现出来翻书的效果.在网上百度了 ...

  7. 深入理解Linux系统调用

    1.系统调用号查询 我的学号位数是08,在64位调用表里可以查到对应的系统调用函数是__x64_sys_lseek 2.lseek函数 由于没用过该函数,所以先去了解一下这个函数的作用: 直白的说就是 ...

  8. Linux查看CPU 内存命令

    查看CPU 内存命令:https://www.cnblogs.com/ggjucheng/archive/2013/01/14/2859613.html 查看某一进程内存占用:ps -ef 获取PID ...

  9. 从XXE漏洞修复引起Not supported: http://javax.xml.XMLConstants/property/accessExternalDTD说到SPI机制

    引子  在使用Fortify扫描时代码报XML External Entity Injection,此漏洞为xml实体注入漏洞,XXE攻击可利用在处理时动态构建文档的 XML 功能.修复方案也包含了增 ...

  10. mysql中数据库的并发事务问题

    1.脏读(dirty-read):如果第二个事务查询到第一个事务还未提交的更新数据,就会形成脏读. 2.幻读/虚读(phantom read):一个事务执行两次,如果出现第二次事务执行比第一次多一些或 ...