Can't load checkpoint in HSDP, stuck at synchronization in `optim_state_dict_to_load`

Hi, I am trying to load my checkpoint in HSDP (HYBRID_SHARD) mode. Training runs fine and checkpoint saving runs fine, but checkpoint loading blocks forever (or until some long timeout) on `optim_state_dict_to_load`. I am using torch 2.1.2, and the same code works fine with FULL_SHARD FSDP.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)

# `module`, `self.model`, `optimizer`, and `optim_state_dict` come from the
# surrounding Composer code (this is an excerpt).
fsdp_state_dict_type = StateDictType.FULL_STATE_DICT
state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)

with FSDP.state_dict_type(module,
                          state_dict_type=fsdp_state_dict_type,
                          state_dict_config=state_dict_config,
                          optim_state_dict_config=optim_state_dict_config):
    optim_state_dict = FSDP.optim_state_dict_to_load(  # blocks on this line
        optim_state_dict=optim_state_dict, model=self.model, optim=optimizer)

I am not the author of this code; it is part of the Composer library, right here (in the example above I explicitly expanded the variables and the context manager for easier reading).
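For context, the saving side (which does work for me) gathers a full state dict under the same context manager. Below is only a minimal sketch of what that looks like with the torch 2.1 API, assuming Composer does something equivalent; the exact Composer code may differ.

# Sketch of the save side under the same FULL_STATE_DICT configs.
# Assumption: this mirrors what Composer does; the real code lives in the library.
with FSDP.state_dict_type(module,
                          state_dict_type=fsdp_state_dict_type,
                          state_dict_config=state_dict_config,
                          optim_state_dict_config=optim_state_dict_config):
    model_state = module.state_dict()                        # full state dict, offloaded to CPU, rank 0 only
    optim_state = FSDP.optim_state_dict(module, optimizer)   # full optimizer state, same config
# rank 0 then writes {"model": model_state, "optim": optim_state} to disk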

Since this works with FULL_SHARD, my guess is that the hang comes from waiting for processes on the other node, which HSDP should not be waiting for (with HSDP, sharding here happens only within one node, not across nodes). See the sketch below for what I mean by that.
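My understanding of HYBRID_SHARD is that it shards parameters across an intra-node group and replicates them across an inter-node group. A sketch of how such groups can be built explicitly, just to make the topology concrete (assumptions: 2 nodes x 4 GPUs, `model` is the unwrapped module; if no tuple is passed, FSDP builds equivalent groups itself):

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Assumption: 2 nodes x 4 GPUs, global ranks 0-3 on node 0 and 4-7 on node 1.
gpus_per_node = 4
num_nodes = dist.get_world_size() // gpus_per_node
my_rank = dist.get_rank()

# Intra-node (shard) groups; every rank must call new_group for every group,
# even for groups it is not a member of.
shard_groups = [
    dist.new_group(ranks=list(range(n * gpus_per_node, (n + 1) * gpus_per_node)))
    for n in range(num_nodes)
]
# Inter-node (replication) groups: one rank per node.
replicate_groups = [
    dist.new_group(ranks=[n * gpus_per_node + i for n in range(num_nodes)])
    for i in range(gpus_per_node)
]

my_shard_pg = shard_groups[my_rank // gpus_per_node]
my_replicate_pg = replicate_groups[my_rank % gpus_per_node]

# HYBRID_SHARD accepts a (shard_group, replicate_group) tuple via process_group.
fsdp_module = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    process_group=(my_shard_pg, my_replicate_pg),
)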

To debug this, I used 2 nodes with 4 GPUs each and printed the following information about the process groups:

print(f"TOTAL PGs ranks: {str(torch.distributed.distributed_c10d._pg_group_ranks)}")
print(f"TAGS TO PG: {str(torch.distributed.distributed_c10d._tags_to_pg)}")

Output on node 0:

TOTAL PGs ranks: {
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971ac009f0>: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7}, 
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc2d5b0>: {0: 0, 1: 1, 2: 2, 3: 3}, 
-100: {3: 0, 7: 1}, 
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc574f0>: {0: 0, 4: 1}}

TAGS TO PG: {
'': [
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971ac009f0>, 
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc2d5b0>, 
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc574f0>], 
'ptd:0': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971ac009f0>], 
'ptd:1': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc2d5b0>], 
'ptd:3': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc574f0>]}

Output on node 1:

TOTAL PGs ranks: {
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b970d06e6b0>: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7},
-100: {3: 0, 7: 1}, 
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bcf5670>: {4: 0, 5: 1, 6: 2, 7: 3}, 
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bdda5b0>: {0: 0, 4: 1}}

TAGS TO PG: {
'': [
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b970d06e6b0>, 
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bcf5670>, 
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bdda5b0>], 
'ptd:0': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b970d06e6b0>], 
'ptd:2': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bcf5670>], 
'ptd:3': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bdda5b0>]}
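My reading of these groups (an assumption based only on their sizes): ptd:0 is the default/world group with all 8 ranks, ptd:1 and ptd:2 are the per-node shard groups (ranks 0-3 and 4-7), and ptd:3 plus the -100 entry each pair one rank per node, which looks like inter-node replication groups. A slightly more readable way to dump the same information, assuming `get_process_group_ranks` is available in this torch version:

import torch.distributed as dist
from torch.distributed import distributed_c10d

# Print the global ranks of every tagged process group known on this rank.
for tag, pgs in distributed_c10d._tags_to_pg.items():
    for pg in pgs:
        ranks = dist.get_process_group_ranks(pg)  # assumption: present in torch 2.1
        print(f"rank {dist.get_rank()}: tag={tag!r} ranks={ranks}")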

I guessed that the single-node group is ptd:1 on node 0 and ptd:2 on node 1, and passed these to `optim_state_dict_to_load` via its process-group (pg) parameter (if no group is passed, the default group is used, which here was ptd:0; that one contains all 8 GPUs, so "it made sense" that it could have caused the block). This worked on node 0 (or at least it didn't block), but failed on node 1 with an error saying that global rank 0 is not part of the group (which it indeed is not, as node 1's group contains global ranks 4-7).
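A variant I have been considering, sketched here so it is easier to comment on: reuse the shard group that FSDP itself stores on the wrapped root module, rather than guessing which entry in `_tags_to_pg` is the right one on each rank. Two assumptions of mine are baked into this: that `self.model.process_group` is the intra-node shard group under HYBRID_SHARD, and that the process-group argument of `optim_state_dict_to_load` is named `group` in torch 2.1.

# Sketch: pass the shard group FSDP already created instead of guessing the
# right entry in _tags_to_pg. `process_group` being the intra-node shard group
# and `group` being the keyword name are assumptions on my part.
shard_pg = self.model.process_group
with FSDP.state_dict_type(module,
                          state_dict_type=fsdp_state_dict_type,
                          state_dict_config=state_dict_config,
                          optim_state_dict_config=optim_state_dict_config):
    optim_state_dict = FSDP.optim_state_dict_to_load(
        optim_state_dict=optim_state_dict,
        model=self.model,
        optim=optimizer,
        group=shard_pg,
    )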

Any ideas what the issue could be, or what I could try? I have spent days on this one :frowning:

Sorry that this was a tough bug! Would it be easier to submit an issue to the Composer repo? There have been some recent fixes in PyTorch, but off the top of my head I am not sure whether they are related.