My bad. Every time you call `optimizer.consolidate_state_dict(rank)`, it clears any previously consolidated state dicts, so the loop `for rank in range(dist.get_world_size())` was actually clearing rank 0's consolidated state dict.
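For reference, the failing pattern presumably looked something like this (a sketch; the trailing rank-0 save is my assumption about the original code):

```python
# Presumed failing pattern: each consolidate_state_dict() call clears the
# previously consolidated state, so by the time rank 0 saves, its copy is gone.
for rank in range(dist.get_world_size()):
    optimizer.consolidate_state_dict(rank)
if dist.get_rank() == 0:
    torch.save(...)  # rank 0's consolidated state dict was already cleared
```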
This should work (I checked this one locally):
```python
optimizer.consolidate_state_dict(0)
if dist.get_rank() == 0:
    torch.save(...)
```
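For context, here is a slightly fuller sketch of the rank-0 checkpointing flow; it assumes `optimizer` is a `ZeroRedundancyOptimizer` from `torch.distributed.optim`, an already-initialized process group, and a hypothetical checkpoint path:

```python
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

# Assumes: optimizer = ZeroRedundancyOptimizer(model.parameters(),
#     optimizer_class=torch.optim.Adam, lr=1e-3)

def save_checkpoint(model, optimizer, path="checkpoint.pt"):  # path is an assumption
    # Collective call: gathers every rank's optimizer shard onto rank 0.
    optimizer.consolidate_state_dict(0)
    if dist.get_rank() == 0:
        # state_dict() only returns the full state on the rank that
        # received the consolidated shards, hence the rank check.
        torch.save(
            {"model": model.state_dict(), "optim": optimizer.state_dict()},
            path,
        )
```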
Or, if you want to save on all ranks, then you should call `torch.save()` inside the loop (I did not check this one):
```python
for rank in range(dist.get_world_size()):
    optimizer.consolidate_state_dict(rank)
    if dist.get_rank() == rank:
        torch.save(...)
```
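If you do save on all ranks, make sure each rank writes to its own file so the ranks do not overwrite each other on a shared filesystem; a sketch with a hypothetical naming scheme:

```python
for rank in range(dist.get_world_size()):
    optimizer.consolidate_state_dict(rank)  # collective: all ranks participate
    if dist.get_rank() == rank:
        # Per-rank filename (the naming scheme is an assumption).
        torch.save(optimizer.state_dict(), f"optim_state_rank{rank}.pt")
```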