How can I save grad attributes within a network's state_dict?

Hi All,

I was wondering if it is possible to save the .grad attributes of all of my model's parameters? I'm currently using a custom optimizer in my problem and it requires the .grad attribute (as I'm preconditioning gradients).

The way I save my model is via,

    torch.save({'epoch':epoch,
                'model_state_dict':net.state_dict(),
                'optim_state_dict':optim.state_dict(),
                'loss':loss}, model_path)

and is subsequently loaded via,

    state_dict = torch.load(f=model_path, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict['model_state_dict'])   # returns <All keys matched successfully>
    optim.load_state_dict(state_dict['optim_state_dict']) # returns nothing (I assume that's ok?)

If I load the model and print the .grad attribute, it doesn't exist and hence causes my code to crash. So, I was wondering how exactly I can save these values? I printed the .grad attributes via,

    for name, param in net.named_parameters():
      print(param.grad)

The strange thing is that I call loss.backward() before calling optim.step(), so I thought the .grad attributes would be populated? This is done in the standard way, like,

    optim.zero_grad()
    loss_mean.backward()
    optim.step()

Any help would be appreciated! Thank you! :slight_smile:

The state_dict doesn't contain the .grad attributes of the parameters, but you could try to store a custom dict with them and reassign these gradients to all needed parameters after loading the state_dict.

EDIT: state_dict hooks might also be a valuable approach you could take a look into.
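
Something along these lines, as a rough sketch that reuses the net, optim, epoch, loss, and model_path variables from your code (the 'grads' key is just an arbitrary name):

    # save the .grad attributes alongside the usual state_dicts
    grads = {name: param.grad for name, param in net.named_parameters()}
    torch.save({'epoch':epoch,
                'model_state_dict':net.state_dict(),
                'optim_state_dict':optim.state_dict(),
                'grads':grads,
                'loss':loss}, model_path)

    # ... later, after loading the state_dicts, reassign the gradients
    checkpoint = torch.load(f=model_path, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optim_state_dict'])
    for name, param in net.named_parameters():
      if checkpoint['grads'][name] is not None:
        param.grad = checkpoint['grads'][name].to(param.device)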

Hi @ptrblck!

One thing I've just changed, which seems to have solved the missing .grad attributes issue, is how I get the gradient's shape. For example, when cycling through optim.param_groups I originally got the shape via the code below.

    module = group['module']
    g = module.weight.grad
    s = g.shape

However, I’ve changed it to,

    g = group['params'][0].grad
    s = g.shape

which seems to work, I assume because I'm accessing the gradient via the parameters rather than the module, and the .grad attributes are filled when loss.backward() is called?
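
For reference, here's a tiny standalone check of that assumption, using a throwaway nn.Linear and SGD rather than my actual network and optimizer:

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    optim = torch.optim.SGD(net.parameters(), lr=0.1)

    loss = net(torch.randn(8, 4)).sum()
    loss.backward()

    # the tensors held in the param groups are the same parameter objects
    # that autograd populates, so .grad is filled after backward()
    for group in optim.param_groups:
      for p in group['params']:
        print(p.shape, p.grad.shape)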

Thank you!

I’m unsure where you are assigning this shape to and how it relates to the state_dict. Could you explain this workflow a bit more, please?


OK, so I'm making a custom optimizer based on KFAC, and I noticed that if I resume my model from a checkpoint it fails because the .grad attributes don't exist. When I first initialise the optimizer, I append each module to a list and register a forward and a backward hook on each module. I also store each module's parameters in the params list. This is done via,

    def _prepare_model(self):
      print("Adding parameters to module, marking nn.Linear layers!")
      for module in self.net.modules():
        if(module.__class__.__name__ in self._kfac_accepted_modules):
          print("module: ", module)
          self.modules.append(module)

          # forward pre-hook stores the layer input, full backward hook
          # stores the gradient w.r.t. the layer's output
          handle = module.register_forward_pre_hook(self._save_input)
          self._fwd_handles.append(handle)
          handle = module.register_full_backward_hook(self._save_output)
          self._bwk_handles.append(handle)

          # each param group keeps both the parameters and the module itself
          params = [module.weight]
          if(module.bias is not None):
            params.append(module.bias)
          d = {'params':params, 'module':module}
          self.params.append(d)

Then when I go to precondition the gradients, I grab the gradient's shape via group['module'].weight.grad.shape. This didn't work, but changing group['module'].weight.grad to group['params'][0].grad did work, since that attribute does exist and gets created when loss.backward() is called. This is done via the following, which is called on each group within optim.param_groups:

    def _precondition(self, group):
      module = group['module']

      # gradient of the weight, accessed via the param group entry
      # rather than via the module attribute
      g = group['params'][0].grad
      s = g.shape
      if(module.bias is not None):
        #gb = module.bias.grad       #doesn't work
        gb = group['params'][1].grad #does work
        # append the bias gradient as an extra column of the weight gradient
        g = torch.cat([g, gb.unsqueeze(1)], dim=1)

Does this clarify the workflow? Thank you!


I hope that means my solution is correct! :wink:

I was wondering if I could ask a follow-up question? With register_full_backward_hook, if I call a double backward on my network, does the backward hook record the double backward or just a single backward? And if I call two backwards in a row (i.e. a double backward with one loss function, then a single backward with a different loss function), which output is stored?

My forward/backward hooks are defined like so,

    def _save_input(self, module, input):
      # input to the layer, saved for preconditioning
      self.state[module]['a'] = input[0]

    def _save_output(self, module, grad_input, grad_output):
      # gradient w.r.t. the layer's output
      self.state[module]['gl'] = grad_output[0]

So, I'd assume that whenever backward() is called, it overwrites the result of the previous call?
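
As a quick sanity check of that assumption (just the two-consecutive-backwards case, with a throwaway layer rather than my actual KFAC code):

    import torch
    import torch.nn as nn

    stored = {}

    def save_output(module, grad_input, grad_output):
      # keeps only the grad_output of the most recent backward pass
      stored['gl'] = grad_output[0]

    lin = nn.Linear(3, 3)
    lin.register_full_backward_hook(save_output)

    x = torch.randn(2, 3, requires_grad=True)

    lin(x).sum().backward()
    first = stored['gl']          # grad_output from the first backward

    (lin(x) ** 2).sum().backward()
    print(torch.equal(first, stored['gl']))  # False: the second backward overwrote it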

Thank you for all the help! :slight_smile: