What’s the Best Way to Debug FSDP When Hitting a C++ Backend Error?

Hello,

I’m currently working on adding MPI backend support for both FSDP1 and FSDP2 as part of an effort to contribute to PyTorch’s distributed training stack. During testing, I encountered an error in FSDP2 during the backward pass. I would appreciate any advice on how to investigate this issue further — especially regarding debugging techniques for distributed code involving C++ internals.

Environment and Reproduction

Here is the code I used:

import os
import argparse

import torch
from torch import nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
import torch.optim as optim


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.sequential0 = nn.Linear(512, 1024)
        self.sequential1 = nn.Linear(1024, 1024)
        self.sequential2 = nn.Linear(1024, 512)
        self.last = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.sequential0(x))
        x = torch.relu(self.sequential1(x))
        x = torch.relu(self.sequential2(x))
        return self.last(x)


def fsdp_training(local_rank: int):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device('cuda', local_rank)

    model = SimpleModel().to(device)
    fully_shard(model)  # FSDP2: shard the model's parameters across the process group

    if rank == 0:
        print(model)

    optimizer = optim.AdamW(model.parameters(), lr=0.01)
    input_data = torch.randn(2, 512).to(device)
    target = torch.randint(0, 10, (2,)).to(device)

    print(f"Rank {rank}/{world_size}: Start training")
    model.train()
    num_epochs = 10
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        output = model(input_data)
        loss = nn.CrossEntropyLoss()(output, target)
        loss.backward()
        optimizer.step()



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FSDP Training')
    parser.add_argument(
        '--backend', type=str, choices=['mpi', 'nccl', 'ucc', 'gloo'], default='mpi',
        help='Backend to use for training'
    )
    args = parser.parse_args()

    local_rank = int(os.getenv("LOCAL_RANK", 0))
    if args.backend == 'mpi':
        local_rank = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
    torch.cuda.set_device(local_rank)

    print("local_rank", local_rank, "backend", args.backend)
    if args.backend == 'mpi':
        dist.init_process_group(backend="mpi")
    else: # nccl, ucc, gloo
        dist.init_process_group(backend=args.backend, init_method="env://")


    fsdp_training(local_rank)
    
    dist.destroy_process_group()

Here’s the launcher command I ran:
OMP_NUM_THREADS=1 mpirun -n 2 -- python3 test_fsdp2_mpi.py --backend mpi

I then encountered the following error on both ranks 0 and 1:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/app/pytorch_tests/src/training/test_fsdp2_mpi.py", line 81, in <module>
[rank0]:     fsdp_training(local_rank)
[rank0]:   File "/app/pytorch_tests/src/training/test_fsdp2_mpi.py", line 52, in fsdp_training
[rank0]:     loss.backward()
[rank0]:   File "/app/pytorch/torch/_tensor.py", line 648, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/app/pytorch/torch/autograd/__init__.py", line 354, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/app/pytorch/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/app/pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 297, in _root_post_backward_final_callback
[rank0]:     fsdp_param_group.post_backward()
[rank0]:   File "/app/pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 445, in post_backward
[rank0]:     ) = foreach_reduce(
[rank0]:   File "/app/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/app/pytorch/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 421, in foreach_reduce
[rank0]:     dist.reduce_scatter_tensor(
[rank0]:   File "/app/pytorch/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/app/pytorch/torch/distributed/distributed_c10d.py", line 4416, in reduce_scatter_tensor
[rank0]:     work.wait()
[rank0]: IndexError: map::at
Environment details:
  • OS: Ubuntu 22.04
  • PyTorch: v2.8.0a0+git6120cc8
  • MPI: Open MPI v5.0.7rc2

My Question: How Can I Debug This Effectively?

How do PyTorch developers or contributors typically debug errors like this that originate in PyTorch's internal C++ code?

I’m trying to investigate this issue, but since FSDP is driven from Python and the error occurs in the C++ backend, it’s difficult to debug. Debugging pure C++ with gdb is usually straightforward, but I’m having trouble using it effectively in this hybrid Python/C++ setting.

Any insight or guidance would be greatly appreciated!

Thanks.

Update: I was able to identify the cause of the IndexError using gdb, so I’m sharing the steps here for reference.


1. Run the script directly without mpirun

OMP_NUM_THREADS=1 python3 test_fsdp2_mpi.py --backend mpi

This single-process run produced the same IndexError: map::at as the mpirun launch, which made it possible to debug one process directly under gdb.

2. Launch the script under gdb

OMP_NUM_THREADS=1 gdb --args python3 test_fsdp2_mpi.py --backend mpi

gdb launched correctly; however, the exception thrown from the C++ side was not caught automatically. It likely does not propagate back to the Python main thread in a way that gdb can intercept by default, and it may be masked or wrapped by the Python runtime or the pybind11 layer.
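To make that last point concrete, here is a small sketch (my own addition, not something FSDP or the original script does): the std::out_of_range thrown by std::map::at reaches Python as a plain IndexError, so wrapping the backward call in try/except sees the translated exception even though gdb does not stop on the original C++ throw by default.

def backward_with_diagnostics(loss, rank):
    # Drop-in replacement for loss.backward() in the training loop above.
    try:
        loss.backward()
    except IndexError as err:
        # "map::at" is the what() string of the underlying std::out_of_range,
        # translated to IndexError at the Python/C++ boundary.
        print(f"Rank {rank}: translated C++ exception: {err!r}", flush=True)
        raise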

3. Set a catchpoint for all C++ exceptions

(gdb) catch throw
Catchpoint 1 (throw)

This instructs gdb to break whenever any C++ exception is thrown, regardless of whether it is propagated to Python.

As a result, the catchpoint triggered when the exception was thrown:

Thread 4 "python3" hit Catchpoint 1 (exception thrown), 0x00007f308cd914a1 in __cxa_throw()
from /lib/x86_64-linux-gnu/libstdc++.so.6

Then, running bt (backtrace) gave the following:

(gdb) bt
#0 __cxa_throw
#1 std::__throw_out_of_range
#2 std::map::at (c10d::ReduceOp::AVG)
#3 c10d::ProcessGroupMPI::_reduce_scatter_base(...)
#4 c10d::ProcessGroupMPI::runLoop()
...
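In other words, std::map::at is being called with c10d::ReduceOp::AVG inside ProcessGroupMPI::_reduce_scatter_base, so the MPI backend's reduce-op lookup table has no entry for AVG; that missing key is what surfaces in Python as IndexError: map::at from work.wait().

For reference, the following standalone sketch should hit the same throw on the same build (an assumption on my part, based on the backtrace: reduce_scatter_tensor with ReduceOp.AVG is the failing call; the tensor shapes are arbitrary). It can be launched the same way as the original script, or as a single process as in step 1.

import torch
import torch.distributed as dist

dist.init_process_group(backend="mpi")
world_size = dist.get_world_size()

# Each rank contributes world_size * 4 elements and receives a 4-element shard.
inp = torch.randn(world_size * 4)
out = torch.empty(4)

# ReduceOp.AVG is the op FSDP2's foreach_reduce passed in the failing run; the
# backtrace indicates it is missing from ProcessGroupMPI's reduce-op map.
work = dist.reduce_scatter_tensor(out, inp, op=dist.ReduceOp.AVG, async_op=True)
work.wait()  # expected to raise IndexError: map::at here

dist.destroy_process_group()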

Note: I also tried import faulthandler; faulthandler.disable() and import signal; signal.signal(signal.SIGSEGV, signal.SIG_DFL), but neither had any effect on catching the exception.
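When the problem only reproduces under mpirun, one option that avoids wrapping every rank in gdb is to pause a single rank and attach gdb to its PID from another terminal (gdb -p <pid>, then catch throw and continue). A hypothetical helper along those lines, not part of the script above:

import os
import time

import torch.distributed as dist


def wait_for_gdb_attach(rank_to_debug: int = 0, timeout_s: int = 60) -> None:
    # Call this right after init_process_group(); only the chosen rank pauses.
    if dist.get_rank() == rank_to_debug:
        print(f"Rank {rank_to_debug}: attach with gdb -p {os.getpid()}", flush=True)
        time.sleep(timeout_s)  # attach, set 'catch throw', then 'continue' within this window
    dist.barrier()  # keep the other ranks in step while the debugger is attached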
