PyTorch分布式训练调试方法(跟踪调用过程)

发布于:2025-04-19 ⋅ 阅读:(20) ⋅ 点赞:(0)

PyTorch分布式训练调试方法(跟踪调用过程)

背景

在分布式深度学习训练场景中,通信操作(如AllReduce、Send/Recv)和CUDA操作的时序问题往往难以调试。本工具通过以下方式提供调试支持:

  1. 拦截所有PyTorch张量操作并记录调用栈
  2. 监控分布式通信操作的完整生命周期
  3. 自动生成带时间戳的详细日志
  4. 支持多GPU并行调试(每个进程独立日志)

方法

本工具采用PyTorch官方推荐的扩展方式实现:

  1. TorchDispatchMode:拦截所有张量操作
  2. Monkey Patch:重写分布式通信原语
  3. 异步日志:确保日志完整性
  4. 调用栈追踪:定位操作发起位置

操作步骤

# 禁用可能产生干扰的第三方扩展库
import sys
sys.modules['apex'] = None
sys.modules['transformer_engine'] = None

import os
import torch
from functools import partial
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
from datetime import datetime
import time
import os
import pickle
import inspect

# 初始化日志系统(每个进程独立日志)
glog=open(f"trace_rank{
     os.environ['RANK']}.log","w")

def save_info(msg):
    """带缓冲刷新的日志记录函数"""
    glog.write(f"{
     msg}\n")
    glog.flush()
    
@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None

class TorchDumpDispatchMode(TorchDispatchMode):
    def __init__(self,parent):
        super().__init__()
        self.parent=parent

    def is_allow_dump(self,name):
        """过滤不需要记录的操作"""
        black_list=["_has_compatible_shallow_copy_type"]
        for i in black_list:
            if name.find(i)>=0:
                return False
        return True

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        func_packet = func._overloadpacket
        op_name=f"{
     func}"
        enable_dump

网站公告

今日签到

点亮在社区的每一天
去签到