Hello,
I’m currently working on adding MPI backend support for both FSDP1 and FSDP2 as part of an effort to contribute to PyTorch’s distributed training stack. During testing, I encountered an error in FSDP2 during the backward pass. I would appreciate any advice on how to investigate this issue further — especially regarding debugging techniques for distributed code involving C++ internals.
Environment and Reproduction
Here is the code I used:
```python
import os
import time
import argparse
import dataclasses

import torch
from torch import nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
import torch.optim as optim


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.sequential0 = nn.Linear(512, 1024)
        self.sequential1 = nn.Linear(1024, 1024)
        self.sequential2 = nn.Linear(1024, 512)
        self.last = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.sequential0(x))
        x = torch.relu(self.sequential1(x))
        x = torch.relu(self.sequential2(x))
        return self.last(x)


def fsdp_training(local_rank: int):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device('cuda', local_rank)

    model = SimpleModel().to(device)
    fully_shard(model)
    if rank == 0:
        print(model)

    optimizer = optim.AdamW(model.parameters(), lr=0.01)
    input_data = torch.randn(2, 512).to(device)
    target = torch.randint(0, 10, (2,)).to(device)

    print(f"Rank {rank}/{world_size}: Start training")
    model.train()
    num_epochs = 10
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        output = model(input_data)
        loss = nn.CrossEntropyLoss()(output, target)
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FSDP Training')
    parser.add_argument(
        '--backend', type=str, choices=['mpi', 'nccl', 'ucc', 'gloo'], default='mpi',
        help='Backend to use for training'
    )
    args = parser.parse_args()

    local_rank = int(os.getenv("LOCAL_RANK", 0))
    if args.backend == 'mpi':
        local_rank = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
    torch.cuda.set_device(local_rank)
    print("local_rank", local_rank, "backend", args.backend)

    if args.backend == 'mpi':
        dist.init_process_group(backend="mpi")
    else:  # nccl, ucc, gloo
        dist.init_process_group(backend=args.backend, init_method="env://")

    fsdp_training(local_rank)
    dist.destroy_process_group()
```
Here’s the launcher command I ran.
```
OMP_NUM_THREADS=1 mpirun -n 2 -- python3 test_fsdp2_mpi.py --backend mpi
```
Then I encounter the following error on both ranks 0 and 1:
```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/app/pytorch_tests/src/training/test_fsdp2_mpi.py", line 81, in <module>
[rank0]:     fsdp_training(local_rank)
[rank0]:   File "/app/pytorch_tests/src/training/test_fsdp2_mpi.py", line 52, in fsdp_training
[rank0]:     loss.backward()
[rank0]:   File "/app/pytorch/torch/_tensor.py", line 648, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/app/pytorch/torch/autograd/__init__.py", line 354, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/app/pytorch/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/app/pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 297, in _root_post_backward_final_callback
[rank0]:     fsdp_param_group.post_backward()
[rank0]:   File "/app/pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 445, in post_backward
[rank0]:     ) = foreach_reduce(
[rank0]:   File "/app/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/app/pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 421, in foreach_reduce
[rank0]:     dist.reduce_scatter_tensor(
[rank0]:   File "/app/pytorch/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/app/pytorch/torch/distributed/distributed_c10d.py", line 4416, in reduce_scatter_tensor
[rank0]:     work.wait()
[rank0]: IndexError: map::at
```
- OS: Ubuntu 22.04
- PyTorch: v2.8.0a0+git6120cc8
- MPI: Open MPI v5.0.7rc2
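Since the traceback points at `dist.reduce_scatter_tensor` inside `foreach_reduce`, my next step is to try a minimal sketch like the one below, which calls just that single collective on the MPI backend outside of FSDP2 (the file name and tensor shapes are mine and only illustrative, not taken from the FSDP2 internals), to see whether it already reproduces the `IndexError: map::at`:

```python
# isolate_reduce_scatter.py -- minimal sketch, launched the same way:
#   OMP_NUM_THREADS=1 mpirun -n 2 -- python3 isolate_reduce_scatter.py
import os

import torch
import torch.distributed as dist

if __name__ == "__main__":
    # Same rank/device mapping as the FSDP2 script above.
    local_rank = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="mpi")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cuda", local_rank)

    # reduce_scatter_tensor: input numel is world_size times the output numel.
    inp = torch.randn(world_size * 4, device=device)
    out = torch.empty(4, device=device)
    dist.reduce_scatter_tensor(out, inp, op=dist.ReduceOp.SUM)

    torch.cuda.synchronize()
    print(f"rank {rank}/{world_size}: reduce_scatter_tensor completed", flush=True)
    dist.destroy_process_group()
```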
My question: how can I debug this effectively?
How do PyTorch developers or contributors typically debug errors like this that originate in the C++ internals?
I'm trying to investigate this issue, but since FSDP is driven from Python and the error surfaces in the C++ backend, it is hard to narrow down. Debugging with gdb is usually straightforward in a pure C++ program, but I'm having trouble using it effectively in this hybrid Python/C++ case, especially with two ranks launched by mpirun.
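The only workflow I can think of so far is something along these lines (just a sketch; `WAIT_FOR_GDB` is an environment variable name I made up): pause each rank right after `init_process_group`, attach `gdb -p <pid>` from another terminal, set a breakpoint in the c10d/ProcessGroupMPI code, and then let `loss.backward()` continue. Is that roughly what contributors do, or is there a better approach (for example a debug build, or extra logging in the C++ layer)?

```python
# Sketch: pause each rank so gdb can be attached before the failing collective.
# WAIT_FOR_GDB is a made-up variable name; attach with `gdb -p <pid>` from
# another terminal while the rank sleeps, set breakpoints, then continue.
import os
import time


def wait_for_debugger(rank: int, timeout_s: int = 60) -> None:
    if os.getenv("WAIT_FOR_GDB", "0") == "1":
        print(f"rank {rank}: pid {os.getpid()} waiting {timeout_s}s for gdb attach",
              flush=True)
        time.sleep(timeout_s)

# Intended use in the script above, right after init_process_group():
#     wait_for_debugger(dist.get_rank())
```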
Any insight or guidance would be greatly appreciated!
Thanks.