Hi, I have found PowerSGD very helpful in my training scenario: bandwidth has been the limiting factor in training speed, and PowerSGD cuts the iteration time significantly while maintaining model quality.
However, it is my understanding that the PowerSGD hook is stateful, so if I pause training and resume later I need to save the hook state as well. Could anyone show me how to do that? I’ve been playing around with it, and it seems that this would work:
from typing import Any, Dict

from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import PowerSGDState

class MyOwnPowerSGDState(PowerSGDState):
    def state_dict(self):
        # PowerSGDState stores its attributes in __slots__, so copy each slot out
        return {k: getattr(self, k) for k in self.__slots__}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        for k in self.__slots__:
            setattr(self, k, state_dict[k])
However, I have not been able to resume the run: it fails with a shape-mismatch error on the tensors. Could anyone help me with this? Thanks!
I don’t believe there’s a clean solution here, but I managed to get the following to work.
Say you’re doing it in the following way:
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks as hooks

state = hooks.powerSGD.PowerSGDState(dist.group.WORLD)
# `model` is your DistributedDataParallel-wrapped module
hooks.register_ddp_comm_hook(hooks.DDPCommHookType.POWER_SGD, model, state)
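(For reference, the same thing can be done through DDP’s documented register_comm_hook method. A minimal sketch, where model is assumed to be your DDP-wrapped module and the hyperparameter values are just illustrative:)

import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

state = powerSGD.PowerSGDState(
    process_group=dist.group.WORLD,
    matrix_approximation_rank=1,    # rank of the low-rank compression
    start_powerSGD_iter=1_000,      # run vanilla allreduce for the first 1k steps
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)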
The way to fold the hook state into your checkpoint is the following:
state: PowerSGDState = ...
# The process group handle is not picklable, so detach it before saving
pg = state.process_group
state.process_group = None
state_dict = {
    'powerSGD_state': state,
    'model': ...,  # e.g. your model's state_dict()
}
torch.save(state_dict, "/tmp/my_model")
# NB: restore pg in `state`, otherwise you won't be able to keep training
state.process_group = pg
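If you save checkpoints in more than one place, it may be worth wrapping the detach/restore dance in a small helper so the process group can never accidentally stay None. A sketch (the helper name is mine, not part of PyTorch):

from contextlib import contextmanager

@contextmanager
def detached_process_group(state):
    # Temporarily strip the (unpicklable) process group from the hook state
    pg = state.process_group
    state.process_group = None
    try:
        yield state
    finally:
        # Always re-attach, even if torch.save raises
        state.process_group = pg

# Usage:
# with detached_process_group(state):
#     torch.save({'powerSGD_state': state, 'model': model.state_dict()}, "/tmp/my_model")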
Then, when loading, you’d do this instead:
state_dict = torch.load("/tmp/my_model")
state = state_dict['powerSGD_state']
# Re-attach the process group that was detached before saving
state.process_group = dist.group.WORLD
hooks.register_ddp_comm_hook(hooks.DDPCommHookType.POWER_SGD, model, state)
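One thing the snippet glosses over: on resume you also need to load the model weights and re-wrap the module in DDP before registering the hook. A sketch of the full resume order, assuming the checkpoint layout above (`net` and `local_rank` are placeholder names of mine):

import torch
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks as hooks
from torch.nn.parallel import DistributedDataParallel as DDP

state_dict = torch.load("/tmp/my_model")

net.load_state_dict(state_dict['model'])    # `net` is the bare (unwrapped) module
model = DDP(net, device_ids=[local_rank])   # re-wrap before registering the hook

state = state_dict['powerSGD_state']
state.process_group = dist.group.WORLD      # re-attach the detached process group
hooks.register_ddp_comm_hook(hooks.DDPCommHookType.POWER_SGD, model, state)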
I believe @Rohan_Varma might have a better idea on how to address this issue.