PyTorch分布式训练调试方法(跟踪调用过程)
背景
在分布式深度学习训练场景中,通信操作(如AllReduce、Send/Recv)和CUDA操作的时序问题往往难以调试。本工具通过以下方式提供调试支持:
- 拦截所有PyTorch张量操作并记录调用栈
- 监控分布式通信操作的完整生命周期
- 自动生成带时间戳的详细日志
- 支持多GPU并行调试(每个进程独立日志)
方法
本工具采用PyTorch官方推荐的扩展方式实现:
- TorchDispatchMode:拦截所有张量操作
- Monkey Patch:重写分布式通信原语
- 异步日志:确保日志完整性
- 调用栈追踪:定位操作发起位置
操作步骤
# 禁用可能产生干扰的第三方扩展库
import sys
sys.modules['apex'] = None
sys.modules['transformer_engine'] = None
import os
import torch
from functools import partial
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
from datetime import datetime
import time
import os
import pickle
import inspect
# 初始化日志系统(每个进程独立日志)
glog=open(f"trace_rank{
os.environ['RANK']}.log","w")
def save_info(msg):
"""带缓冲刷新的日志记录函数"""
glog.write(f"{
msg}\n")
glog.flush()
@dataclass
class _ProfilerState:
cls: Any
object: Any = None
class TorchDumpDispatchMode(TorchDispatchMode):
def __init__(self,parent):
super().__init__()
self.parent=parent
def is_allow_dump(self,name):
"""过滤不需要记录的操作"""
black_list=["_has_compatible_shallow_copy_type"]
for i in black_list:
if name.find(i)>=0:
return False
return True
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
func_packet = func._overloadpacket
op_name=f"{
func}"
enable_dump