Torch.distributed.checkpoint.save hangs while writing the .metadata file

While saving a distributed checkpoint to disk with torch.distributed.checkpoint.save (across 512 ranks), the model shards are saved successfully as __x_y.distcp files, but occasionally the process group hangs while writing the .metadata file and times out, causing the whole job to fail. The part of the code that saves the checkpoint is straightforward and looks something like this (simplified):

import torch.distributed as dist
import torch.distributed.checkpoint as dcp

pg = dist.new_group(backend="gloo")
...
dcp.save(states, checkpoint_id=checkpoint_id, process_group=pg)

Exact reproduction may be difficult, since I’m using MI250X GPUs (ugh) and even on my end the failure is stochastic: maybe half the time dcp.save() succeeds without issue, and half the time the process group hangs as described above.

I’m mainly trying to understand why the .distcp files are always written successfully while the hang only ever seems to happen while the .metadata file is being written. Any pointers and suggestions for potential workarounds would be much appreciated.

My torch version is 2.6.0.dev20241005+rocm6.2, for what it’s worth.
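
As a possible stopgap, would giving the checkpoint group a longer explicit timeout be reasonable? This is only a sketch on my side (the two-hour value is arbitrary, and it assumes the hang is the metadata/finalize step running into the process-group timeout rather than a genuine deadlock):

from datetime import timedelta

import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# Dedicated Gloo group for checkpointing, with a generous timeout so the
# finalize/.metadata phase has more headroom than the default.
pg = dist.new_group(backend="gloo", timeout=timedelta(hours=2))
...
dcp.save(states, checkpoint_id=checkpoint_id, process_group=pg)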

dist.new_group(backend="gloo") creates a connection mesh between the ranks. Do you know which transport Gloo is using under the covers in the two cases? Is it TCP, RDMA, or both? Gloo has support for both.
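
If you want to take interface selection out of the equation while debugging, you can pin Gloo to a specific NIC before the group is created. Rough sketch below; "hsn0" is just a placeholder for whatever high-speed interface your nodes actually expose:

import os
import torch.distributed as dist

# Force Gloo to build its connection mesh over one specific interface,
# so every rank negotiates the same transport path.
os.environ["GLOO_SOCKET_IFNAME"] = "hsn0"  # placeholder; use your NIC name

pg = dist.new_group(backend="gloo")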