Debugging PyTorch DDP Training Code

Published: 2025-02-11

Background

PyTorch provides Distributed Data Parallel (DDP) for parallel training across multiple machines and GPUs, and it ships the torchrun command for launching such jobs. However, a torchrun launch is awkward to debug in an IDE. A practical workaround is to temporarily rewrite the launch as an equivalent mp.spawn call, debug it that way, and then switch back to torchrun for the real training run.

Workflow

Suppose the original DDP training code looks like this:

import os

import torch
import torch.distributed as dist

def main():
    # args is assumed to be defined elsewhere in the script (e.g., via argparse)
    # torchrun injects these environment variables into each worker process
    args.local_rank = int(os.environ["LOCAL_RANK"])
    args.world_size = int(os.environ["WORLD_SIZE"])
    args.rank = int(os.environ["RANK"])
    dist.init_process_group("nccl", rank=args.rank, world_size=args.world_size)
    torch.cuda.set_device(args.local_rank)
    ......

if __name__ == "__main__":
    main()

DDP training is then launched with the following command:

torchrun --nnodes=1 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 train.py
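
For reference, torchrun is what injects the rank information that the script above reads from the environment. A quick way to see this is to launch a tiny script with the same command (check_env.py is just an illustrative name, not part of the original post):

# check_env.py -- illustrative name; prints the per-worker variables set by torchrun
import os

if __name__ == "__main__":
    for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE"):
        print(f"{key}={os.environ.get(key)}")

Run with torchrun, each of the 4 processes prints its own LOCAL_RANK and RANK; run directly from the IDE, these variables are absent, which is why the mp.spawn version below passes the rank in explicitly and sets the rendezvous address by hand.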

To debug inside an IDE (e.g., PyCharm), change the code as follows; it can then be debugged directly in the IDE:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def main(rank, world_size):
    # args is assumed to be defined elsewhere in the script (e.g., via argparse)
    # mp.spawn passes the process index as rank; rendezvous info is set manually
    args.local_rank = rank
    args.world_size = world_size
    args.rank = rank
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=args.rank, world_size=args.world_size)
    torch.cuda.set_device(args.local_rank)
    ......

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
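
If you switch between the two launch modes often, one option (a sketch, not from the original post) is to let a single script detect how it was launched: when torchrun has set LOCAL_RANK in the environment, read the ranks from there; otherwise fall back to mp.spawn. The helper name setup_distributed below is hypothetical:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def setup_distributed(rank=None, world_size=None):
    # Hypothetical helper: unify the torchrun and mp.spawn launch modes.
    if "LOCAL_RANK" in os.environ:
        # torchrun mode: ranks come from the environment torchrun injects.
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        # mp.spawn (debug) mode: rank is the spawn index; set rendezvous info by hand.
        local_rank = rank
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size

def main(rank=None, world_size=None):
    local_rank, rank, world_size = setup_distributed(rank, world_size)
    # ... training code goes here, identical for both launch modes ...
    dist.destroy_process_group()

if __name__ == "__main__":
    if "LOCAL_RANK" in os.environ:
        main()  # launched by torchrun
    else:
        world_size = torch.cuda.device_count()
        mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)  # debug in the IDE

With this pattern the same train.py should work both under the torchrun command above and when run directly from the IDE for debugging, so no code needs to be changed back afterwards.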