ZERO optimizer.consolidate_state_dict() will hang

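My bad. optimizer.consolidate_state_dict(rank) is a collective call that every rank has to make, and each call starts by discarding whatever state dicts were consolidated before; only the target rank then receives the freshly gathered state. So in the loop for rank in range(dist.get_world_size()), every iteration after the first wiped out what rank 0 had received, and by the end only the last rank still held a consolidated state dict.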
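To make the clearing behavior concrete, here is a toy, single-process model of it. `ToyZeroOptimizer`, `consolidate` and `state_dict` are hypothetical stand-ins, not PyTorch's API; they only assume what is described above: each consolidation empties the gathered state on every rank and refills it on the target rank alone. The model reproduces the broken loop and shows that saving right after each consolidation keeps every rank's copy.

```python
class ToyZeroOptimizer:
    """Hypothetical stand-in for one rank's ZeRO optimizer (not the real class)."""

    def __init__(self, rank: int, shard: dict):
        self.rank = rank
        self.shard = shard          # this rank's slice of the optimizer state
        self.all_state_dicts = []   # filled only on the rank that receives the consolidation


def consolidate(optimizers, to):
    """Simulates one collective consolidate_state_dict(to) made by all ranks."""
    for opt in optimizers:
        opt.all_state_dicts = []    # every call first discards earlier results
    optimizers[to].all_state_dicts = [opt.shard for opt in optimizers]


def state_dict(opt):
    """Merges the gathered shards; fails if nothing was consolidated on this rank."""
    if not opt.all_state_dicts:
        raise RuntimeError(f"rank {opt.rank}: state not consolidated on this rank")
    merged = {}
    for shard in opt.all_state_dicts:
        merged.update(shard)
    return merged


world_size = 3
opts = [ToyZeroOptimizer(r, {f"param{r}": r * 10}) for r in range(world_size)]

# Broken pattern: consolidate to every rank first, save on rank 0 afterwards.
for r in range(world_size):
    consolidate(opts, to=r)
print([len(o.all_state_dicts) for o in opts])  # [0, 0, 3]
try:
    state_dict(opts[0])
except RuntimeError as err:
    print(err)  # rank 0: state not consolidated on this rank

# Working pattern: use the state right after consolidating to rank r.
saved = {}
for r in range(world_size):
    consolidate(opts, to=r)
    saved[r] = state_dict(opts[r])
print(saved[0])  # {'param0': 0, 'param1': 10, 'param2': 20}
```

Only the last rank ends up holding anything after the first loop, which is exactly why rank 0 had nothing to save.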

This should work (I checked this one locally):

# Collective call: every rank must run it; only rank 0 receives the full state
optimizer.consolidate_state_dict(0)
if dist.get_rank() == 0:
    torch.save(...)
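A self-contained sketch of the rank-0 pattern, assuming a CPU-only `gloo` process group initialized from a temporary file, `torch.optim.Adam` as the wrapped optimizer, and a hypothetical helper name `save_zero_checkpoint`. It runs as a single process (world size 1) so it can be tried without a launcher; in a real job every rank would execute the same code, with the process group set up by your launcher. The trailing `dist.barrier()` is an extra choice of this sketch, not something the fix requires.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer


def save_zero_checkpoint(optimizer: ZeroRedundancyOptimizer, path: str) -> None:
    """Hypothetical helper: gather the sharded state on rank 0 and save it there."""
    # Collective: every rank must reach this call, or the other ranks block.
    optimizer.consolidate_state_dict(to=0)
    if dist.get_rank() == 0:
        # state_dict() only works on the rank that received the consolidated state.
        torch.save(optimizer.state_dict(), path)
    # Keep ranks in step so nobody reads the file before rank 0 has written it.
    dist.barrier()


# Single-process demo (world size 1, gloo backend, file-based rendezvous).
tmp_dir = tempfile.mkdtemp()
dist.init_process_group(
    backend="gloo",
    init_method="file://" + os.path.join(tmp_dir, "store"),
    rank=0,
    world_size=1,
)

model = torch.nn.Linear(4, 2)
optimizer = ZeroRedundancyOptimizer(
    model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3
)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()  # creates Adam state so there is something to consolidate

ckpt_path = os.path.join(tmp_dir, "zero_optim.pt")
save_zero_checkpoint(optimizer, ckpt_path)
checkpoint = torch.load(ckpt_path, weights_only=False)
print(sorted(checkpoint))
dist.destroy_process_group()
```

The saved file can later be restored with `optimizer.load_state_dict(...)` after the process group and optimizer have been recreated on every rank.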

Or, if you want to save on all ranks, call torch.save() inside the loop, right after each consolidation, so every rank saves before the next call clears its copy (I did not check this one):

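for rank in range(dist.get_world_size()):
    # All ranks join every call; each call discards the previous consolidation
    optimizer.consolidate_state_dict(rank)
    if dist.get_rank() == rank:
        # Save now, before the next iteration clears this rank's copy
        torch.save(...)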