PyTorch分布式训练调试方法(跟踪调用过程)
PyTorch分布式训练调试方法(跟踪调用过程)
背景
在分布式深度学习训练场景中,通信操作(如AllReduce、Send/Recv)和CUDA操作的时序问题往往难以调试。本工具通过以下方式提供调试支持:
- 拦截所有PyTorch张量操作并记录调用栈
- 监控分布式通信操作的完整生命周期
- 自动生成带时间戳的详细日志
- 支持多GPU并行调试(每个进程独立日志)
方法
本工具采用PyTorch官方推荐的扩展方式实现:
- TorchDispatchMode:拦截所有张量操作
- Monkey Patch:重写分布式通信原语
- 异步日志:确保日志完整性
- 调用栈追踪:定位操作发起位置
操作步骤
# 禁用可能产生干扰的第三方扩展库
import sys
sys.modules['apex'] = None
sys.modules['transformer_engine'] = Noneimport os
import torch
from functools import partial
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
from datetime import datetime
import time
import os
import pickle
import inspect# 初始化日志系统(每个进程独立日志)
glog=open(f"trace_rank{os.environ['RANK']}.log","w")def save_info(msg):"""带缓冲刷新的日志记录函数"""glog.write(f"{msg}\n")glog.flush()@dataclass
class _ProfilerState:cls: Anyobject: Any = Noneclass TorchDumpDispatchMode(TorchDispatchMode):def __init__(self,parent):super().__init__()self.parent=parentdef is_allow_dump(self,name):"""过滤不需要记录的操作"""black_list=["_has_compatible_shallow_copy_type"]for i in black_list:if name.find(i)>=0:return Falsereturn Truedef __torch_dispatch__(self, func, types, args=(), kwargs=None):func_packet = func._overloadpacketop_name=f"{func}"enable_dump