Torch.distributed.checkpoint.save hangs while writing the .metadata file

While saving a distributed checkpoint to disk with torch.distributed.checkpoint.save (across 512 ranks), the model parts are successfully saved in __x_y.distcp files, but occasionally the process group hangs while writing the .metadata file and times out, causing the whole job to fail. The part of the code that saves the checkpoint is straightforward and looks something like this (simplified):

import torch.distributed as dist
import torch.distributed.checkpoint as dcp

pg = dist.new_group(backend="gloo")
...
dcp.save(states, checkpoint_id=checkpoint_id, process_group=pg)
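
For reference, the snippet above creates the group with the default collective timeout; dist.new_group also accepts a timeout argument, so that window can be widened as a stopgap if the metadata step turns out to be slow rather than truly deadlocked. A minimal sketch (the 60-minute value is arbitrary, not a recommendation):

from datetime import timedelta

import torch.distributed as dist

# Widen the collective timeout for this group; 60 minutes is just an
# example value.
pg = dist.new_group(backend="gloo", timeout=timedelta(minutes=60))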

Exact replication may be difficult, since I’m using MI250X GPUs (ugh) and even on my end the failure is stochastic: maybe half the time dcp.save() succeeds without an issue and half the time the process group hangs as described above.

I’m just trying to understand why the .distcp files are always successfully saved, and the process group seems to hang only when writing the .metadata file. Any pointers and suggestions for potential workarounds would be much appreciated.

My torch version is 2.6.0.dev20241005+rocm6.2, for what it’s worth.

dist.new_group(backend="gloo") creates a connection mesh between the ranks.
Do you know what transport is being used by gloo under the covers for the 2 cases?
Is it TCP or RDMA or both?
Gloo has support for both.
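
For reference on the TCP side: PyTorch's Gloo backend honors the GLOO_SOCKET_IFNAME environment variable, so you can pin which interface Gloo binds to and at least rule out it picking the wrong NIC (the interface name below is just a placeholder for whatever your nodes expose):

import os

# Pin Gloo's TCP transport to a known network interface before the process
# group is created; "eth0" is a placeholder, not a recommendation.
os.environ["GLOO_SOCKET_IFNAME"] = "eth0"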

You said that when it hangs, you get a TimeoutError. Would you be able to post the full output when this occurs?
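
If the log is sparse when it happens, bumping the distributed logging verbosity before launch usually makes the eventual timeout message much more informative; both variables below are standard PyTorch debugging env vars:

import os

# More verbose c10d logging; set these before torch.distributed is initialized,
# or export them in the launch script instead.
os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"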

I am also curious how consistent the TimeoutErrors are if you were to just “try again” for the same job.

There’s likely a better implementation, but something like this:

try:
    dcp.save(states, checkpoint_id=checkpoint_id, process_group=pg)
except TimeoutError:
    dcp.save(states, checkpoint_id=checkpoint_id, process_group=pg)
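
Or, slightly more defensive (still assuming a retry is even safe once the gloo group has timed out, which is worth verifying on your setup):

for attempt in range(3):
    try:
        dcp.save(states, checkpoint_id=checkpoint_id, process_group=pg)
        break
    except TimeoutError:
        # Give up after the final attempt instead of swallowing the error.
        if attempt == 2:
            raise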

Also, I am currently looking around here in the source code, since it only fails at the metadata-writing step.
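
For what it's worth, my rough mental model of that last step (an illustration only, not the actual DCP code) is that each rank's write results get gathered over the process group and only the coordinator rank serializes the .metadata file. That would explain why the __x_y.distcp shards always make it to disk while the hang only ever shows up at the end: the shard writes are local file I/O, but the metadata step starts with a collective that every one of the 512 ranks has to reach before the timeout fires. Roughly:

import pickle

import torch.distributed as dist

def finish_checkpoint(local_result, pg, metadata_path):
    # Hypothetical sketch of the coordinator pattern, not DCP's real code:
    # gather per-rank write results, then only rank 0 writes the metadata.
    rank = dist.get_rank()
    world = dist.get_world_size(group=pg)
    gathered = [None] * world if rank == 0 else None
    # Collective over the gloo group: a single slow or wedged rank blocks
    # everyone here, even though all shards are already on disk.
    dist.gather_object(local_result, gathered, dst=0, group=pg)
    if rank == 0:
        with open(metadata_path, "wb") as f:
            pickle.dump(gathered, f)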