Hi, I am trying to load a checkpoint in HSDP (HYBRID_SHARD) mode. Training runs fine, and checkpoint saving works fine. However, checkpoint loading blocks forever on optim_state_dict_to_load
(or until some long timeout). I am using torch 2.1.2, and the same code works fine with FULL_SHARD FSDP.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullStateDictConfig,
    FullOptimStateDictConfig,
)

fsdp_state_dict_type = StateDictType.FULL_STATE_DICT
state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)

with FSDP.state_dict_type(
    module,
    state_dict_type=fsdp_state_dict_type,
    state_dict_config=state_dict_config,
    optim_state_dict_config=optim_state_dict_config,
):
    optim_state_dict = FSDP.optim_state_dict_to_load(  # blocks on this line
        optim_state_dict=optim_state_dict, model=self.model, optim=optimizer
    )
I am not the author of this code; it is part of the Composer library, right here (in this snippet I explicitly expanded the variables and context managers for easier understanding).
Since this works with FULL_SHARD, my guess is that the hang comes from waiting on processes from the other node, which HSDP should not need to wait for (with HSDP, sharding here happens only within one node, not across nodes).
To debug this, I used 2 nodes, each with 4 GPUs.
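For reference, here is my understanding of the HYBRID_SHARD topology for this setup (an assumption on my part, not taken from FSDP internals), sketched in plain Python: with 2 nodes of 4 GPUs each, FSDP should shard within each node and replicate across nodes.

```python
def expected_hsdp_groups(world_size: int, gpus_per_node: int):
    """Sketch of the process groups I expect HYBRID_SHARD to create:
    shard groups within each node, replicate groups across nodes."""
    shard = [list(range(n * gpus_per_node, (n + 1) * gpus_per_node))
             for n in range(world_size // gpus_per_node)]
    replicate = [list(range(i, world_size, gpus_per_node))
                 for i in range(gpus_per_node)]
    return shard, replicate

# expected_hsdp_groups(8, 4)
# -> shard:     [[0, 1, 2, 3], [4, 5, 6, 7]]
#    replicate: [[0, 4], [1, 5], [2, 6], [3, 7]]
```

If this assumption is right, the {0: 0, 4: 1} group in the dumps below would be one of the replicate groups.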
I printed the following information about the process groups:
print(f"TOTAL PGs ranks: {str(torch.distributed.distributed_c10d._pg_group_ranks)}")
print(f"TAGS TO PG: {str(torch.distributed.distributed_c10d._tags_to_pg)}")
Outputs on node 0:
TOTAL PGs ranks: {
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971ac009f0>: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7},
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc2d5b0>: {0: 0, 1: 1, 2: 2, 3: 3},
-100: {3: 0, 7: 1},
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc574f0>: {0: 0, 4: 1}}
TAGS TO PG: {
'': [
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971ac009f0>,
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc2d5b0>,
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc574f0>],
'ptd:0': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971ac009f0>],
'ptd:1': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc2d5b0>],
'ptd:3': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b975cc574f0>]}
Outputs on node 1:
TOTAL PGs ranks: {
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b970d06e6b0>: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7},
-100: {3: 0, 7: 1},
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bcf5670>: {4: 0, 5: 1, 6: 2, 7: 3},
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bdda5b0>: {0: 0, 4: 1}}
TAGS TO PG: {
'': [
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b970d06e6b0>,
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bcf5670>,
<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bdda5b0>],
'ptd:0': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b970d06e6b0>],
'ptd:2': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bcf5670>],
'ptd:3': [<torch.distributed.distributed_c10d.ProcessGroup object at 0x2b971bdda5b0>]}
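To make sense of these dumps programmatically, I sketched a hypothetical helper over the (private) _pg_group_ranks mapping printed above, which is shaped as {pg: {global_rank: group_rank}}: pick the group that contains my rank and has exactly one node's worth of members.

```python
def find_my_shard_group(pg_group_ranks, my_rank, gpus_per_node):
    """Return the process group (key) containing my_rank whose size
    equals gpus_per_node -- i.e. my node-local shard group."""
    for pg, rank_map in pg_group_ranks.items():
        if my_rank in rank_map and len(rank_map) == gpus_per_node:
            return pg
    return None
```

This is only a debugging aid; _pg_group_ranks is private torch internals and may change between versions.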
I guessed that the single-node group is ptd:1
for node 0 and ptd:2
for node 1, and passed these to optim_state_dict_to_load
via its pg
parameter (if pg
is not passed, the default group is used, which is ptd:0
; that one contains all 8 GPUs, so "it made sense" that it could have caused the block). This worked on node 0 (or at least it did not block), but failed on node 1, raising this
with a message that global rank 0 is not part of the group (which it is not, since node 1's group contains global ranks 4-7).
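One possible explanation (again, my assumption): each rank has to pass a group that it is itself a member of, so a handle to node 0's group cannot be used by processes on node 1. A minimal sketch of the per-rank selection, with a hypothetical helper name:

```python
def my_intra_node_ranks(global_rank: int, gpus_per_node: int):
    """Compute the global ranks of MY node's shard group.
    On node 1 (ranks 4-7) this gives [4, 5, 6, 7]; passing a group
    built from node 0's ranks [0, 1, 2, 3] instead is exactly what
    produces "global rank 0 is not part of the group"."""
    node = global_rank // gpus_per_node
    return list(range(node * gpus_per_node, (node + 1) * gpus_per_node))

# Note: torch.distributed.new_group(ranks=...) is collective -- every
# rank must call it once per subgroup, keeping only the handle for the
# subgroup it actually belongs to.
```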
Any ideas what the issue could be, or what I could try? I have spent days on this one.