The argument to consolidate_state_dict() is the target rank (which is 0 by default). If you want to save the unsharded state dict on all ranks, then you need to loop over all ranks as the target:
for rank in range(dist.get_world_size()):
    optimizer.consolidate_state_dict(rank)
torch.save({"optimizer": optimizer.state_dict()}, "test.pth")
However, that will redundantly save the same state dict from all ranks. If that is not desired, you can do something like:
optimizer.consolidate_state_dict()  # defaults to rank 0
if dist.get_rank() == 0:
    torch.save({"optimizer": optimizer.state_dict()}, "test.pth")
My bad. Every time you call optimizer.consolidate_state_dict(rank), it clears any state dicts previously consolidated on that rank, so the loop for rank in range(dist.get_world_size()) was actually clearing rank 0's consolidated state dict.
This should work (I checked this one locally):
optimizer.consolidate_state_dict(0)
if dist.get_rank() == 0:
    torch.save(...)
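For context, here is a minimal end-to-end sketch of the rank-0 approach. It assumes the optimizer is a torch.distributed.optim.ZeroRedundancyOptimizer (which is where consolidate_state_dict() lives in core PyTorch), that dist.init_process_group() has already been called, and that the model, hyperparameters, and file name are placeholders:

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

def save_checkpoint(model: torch.nn.Module) -> None:
    # Shard the optimizer state across all ranks in the default group.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(), optimizer_class=torch.optim.SGD, lr=0.01
    )
    # ... training steps would go here ...

    # Collective call: every rank participates, gathering the full,
    # unsharded optimizer state onto rank 0.
    optimizer.consolidate_state_dict(0)

    # Only rank 0 holds the consolidated copy; calling state_dict()
    # on the other ranks would raise because nothing was consolidated there.
    if dist.get_rank() == 0:
        torch.save({"optimizer": optimizer.state_dict()}, "test.pth")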
Or, if you want to save on all ranks, then you should call torch.save() in the loop (I did not check this one):
for rank in range(dist.get_world_size()):
    optimizer.consolidate_state_dict(rank)
    if dist.get_rank() == rank:
        torch.save(...)
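Two practical notes on the all-ranks loop: consolidate_state_dict() is a collective call, so every rank must execute it on every iteration (only the torch.save() is guarded), and if the ranks share a filesystem you should give each rank its own file name, e.g. something like f"test_{rank}.pth" (a made-up pattern), so they do not overwrite each other's checkpoint.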