How to resume PowerSGD-enabled training?

Hi, I have found PowerSGD very helpful in my training scenario, where bandwidth has been the limiting factor in training speed. PowerSGD reduces communication time substantially while maintaining model quality.
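For context, this is roughly how I enable it (a sketch; `model` is assumed to already be wrapped in DistributedDataParallel, and the hyperparameter values are just illustrative):

import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# `model` is assumed to be a DistributedDataParallel instance.
state = powerSGD.PowerSGDState(
    process_group=dist.group.WORLD,
    matrix_approximation_rank=2,  # low-rank factor; illustrative value
    start_powerSGD_iter=1_000,    # warm up with vanilla allreduce first
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)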

However, my understanding is that PowerSGD is a stateful hook, so if I pause training and resume later, I need to save and restore that state as well. Could anyone show me how to do that? I have been experimenting, and it seemed like this would work:

from typing import Any, Dict

from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import PowerSGDState


class MyOwnPowerSGDState(PowerSGDState):
    def state_dict(self) -> Dict[str, Any]:
        # PowerSGDState declares its attributes via __slots__.
        return {k: getattr(self, k) for k in self.__slots__}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        for k in self.__slots__:
            setattr(self, k, state_dict[k])

However, I have not been able to resume the run: it gives me a shape error on some of the tensors. Could anyone help me with this? Thanks!

I don’t believe that there’s a good solution here.

I managed to get the following to work.

Say you’re doing it in the following way:

import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import (
    DDPCommHookType, powerSGD, register_ddp_comm_hook)

# `model` is your DistributedDataParallel-wrapped module.
state = powerSGD.PowerSGDState(dist.group.WORLD)
register_ddp_comm_hook(DDPCommHookType.POWER_SGD, model, state)

The way to add state to your state dict is the following:

state: PowerSGDState = ...
# Detach the process group before pickling: it is not serializable.
pg = state.process_group
state.process_group = None

state_dict = {
    'powerSGD_state': state,
    'model': ...,  # your model's state_dict, optimizer state, etc.
}
torch.save(state_dict, "/tmp/my_model")

# NB: restore the process group in `state`, otherwise you won't be able to keep training
state.process_group = pg

Then, when loading, you'd do this instead:

state_dict = torch.load("/tmp/my_model")
state = state_dict['powerSGD_state']
# Reattach the process group that was detached before saving.
state.process_group = dist.group.WORLD
register_ddp_comm_hook(DDPCommHookType.POWER_SGD, model, state)

I believe @Rohan_Varma might have a better idea on how to address this issue.

Hi, thanks for this. So you are saying that I have to detach the process groups before saving and reattach after loading?

Yes, you have to, because the process group is part of the PowerSGDState object and it isn't serializable.
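Putting the two snippets together, a pair of helpers along these lines should work (a sketch; `save_checkpoint` and `load_checkpoint` are illustrative names, not a PyTorch API):

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import PowerSGDState

def save_checkpoint(model, state: PowerSGDState, path: str) -> None:
    # Detach the non-serializable process group before pickling.
    pg = state.process_group
    state.process_group = None
    torch.save({"model": model.state_dict(), "powerSGD_state": state}, path)
    # Reattach it so training can continue after the checkpoint.
    state.process_group = pg

def load_checkpoint(model, path: str) -> PowerSGDState:
    ckpt = torch.load(path)
    model.load_state_dict(ckpt["model"])
    state = ckpt["powerSGD_state"]
    # Reattach the process group for the current run.
    state.process_group = dist.group.WORLD
    return state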

This is a shortcoming of the current design. I filed a feature request to explore addressing it: checkpointing of DDP comms hook · Issue #75666 · pytorch/pytorch · GitHub


Thanks Rodrigo! Agreed that we don’t have a great way to fix this at the moment, we’ll track the fix in checkpointing of DDP comms hook · Issue #75666 · pytorch/pytorch · GitHub

Great! Thanks for the help!