Hi All,
I was wondering if it is possible to save the .grad attributes of all my model's parameters? I'm currently using a custom optimizer, and it requires the .grad attribute (as I'm preconditioning gradients).
The way I save my model is via,
torch.save({'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optim_state_dict': optim.state_dict(),
            'loss': loss}, model_path)
and is subsequently loaded via,
state_dict = torch.load(f=model_path, map_location=torch.device('cpu'))
net.load_state_dict(state_dict['model_state_dict'])  # returns <All keys matched successfully>
optim.load_state_dict(state_dict['optim_state_dict'])  # returns nothing (I assume that's ok?)
If I load the model and print the .grad attribute, it doesn't exist, which causes my code to crash. So I was wondering how exactly I can save these values? I printed the .grad attributes via,
for name, param in net.named_parameters():
    print(param.grad)
The strange thing is that I called loss.backward() before calling optim.step(), so I thought the .grad attributes would already be populated? This is done in the standard way,
optim.zero_grad()
loss_mean.backward()
optim.step()
Any help would be appreciated! Thank you!
The state_dict doesn't contain the .grad attributes of the parameters, but you could try to store a custom dict with them and reassign these gradients to all needed parameters after loading the state_dict.
EDIT: state_dict hooks might also be a valuable workflow you could look into.
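As a minimal sketch of the custom-dict approach, reusing the checkpoint code from your post (the 'grads' key is just an arbitrary name):

# save the .grad tensors alongside the usual checkpoint entries
grads = {name: param.grad for name, param in net.named_parameters()}
torch.save({'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optim_state_dict': optim.state_dict(),
            'grads': grads,
            'loss': loss}, model_path)

# after loading the state_dicts, reassign the stored gradients
checkpoint = torch.load(f=model_path, map_location=torch.device('cpu'))
net.load_state_dict(checkpoint['model_state_dict'])
optim.load_state_dict(checkpoint['optim_state_dict'])
for name, param in net.named_parameters():
    g = checkpoint['grads'][name]
    param.grad = None if g is None else g.clone().to(param.device)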
Hi @ptrblck!
One thing I've just changed, which seems to have solved the missing .grad attribute issue, is how I get the gradient's shape. For example, when cycling through optim.param_groups, I originally got the shape via the code below.
module = group['module']
g = module.weight.grad
s = g.shape
However, I’ve changed it to,
g = module['params'][0].grad
s = g.shape
which seems to work. I assume that's because I'm accessing the gradient via the parameters rather than the module, and the .grad attributes are filled when calling loss.backward()?
Thank you!
I'm unsure what you are assigning this shape to and how it relates to the state_dict. Could you explain this workflow a bit more, please?
Ok, so I'm making a custom optimizer based on KFAC, and I noticed that if I resume my model from a checkpoint it fails because the .grad attributes don't exist. When I first initialise the optimizer, I append each accepted module to a list and register a forward pre-hook and a backward hook on each module. I also store the parameters of these modules in the params list. This is done via,
def _prepare_model(self):
    print("Adding parameters to module, marking nn.Linear layers!")
    for module in self.net.modules():
        if(module.__class__.__name__ in self._kfac_accepted_modules):
            print("module: ", module)
            self.modules.append(module)
            handle = module.register_forward_pre_hook(self._save_input)
            self._fwd_handles.append(handle)
            handle = module.register_full_backward_hook(self._save_output)
            self._bwk_handles.append(handle)
            params = [module.weight]
            if(module.bias is not None):
                params.append(module.bias)
            d = {'params': params, 'module': module}
            self.params.append(d)
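For context, a simplified sketch of how these dicts become optim.param_groups (my actual constructor does more setup; the class name, signature, and lr value here are just placeholders, and it relies on the _prepare_model method shown above):

import torch

class KFACLike(torch.optim.Optimizer):  # simplified sketch, not the full optimizer
    def __init__(self, net, lr=1e-2):
        self.net = net
        self.modules = []
        self._fwd_handles, self._bwk_handles = [], []
        self.params = []
        self._kfac_accepted_modules = ('Linear',)
        self._prepare_model()
        # each dict appended to self.params becomes one entry of self.param_groups
        super().__init__(self.params, dict(lr=lr))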
Then, when I go to precondition the gradient, I grab its shape via group['module'].weight.grad.shape. This didn't work, but what did work was to change group['module'].weight.grad to group['params'][0].grad, which does exist and gets created when loss.backward() is called. This is done via the following, which is called on each group within optim.param_groups,
def _precondition(self, group):
    module = group['module']
    g = group['params'][0].grad
    s = g.shape
    if(module.bias is not None):
        #gb = module.bias.grad  # doesn't work
        gb = group['params'][1].grad  # does work
        g = torch.cat([g, gb.unsqueeze(1)], dim=1)
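As a sanity check (a minimal diagnostic sketch using the names from the snippets above), one can print whether the param entry and the module attribute still refer to the same tensor after loading the checkpoint, and whether each of them has a gradient:

for group in optim.param_groups:
    module = group['module']
    weight = group['params'][0]
    print(module.__class__.__name__,
          'same tensor:', weight is module.weight,
          'module grad:', module.weight.grad is not None,
          'param grad:', weight.grad is not None)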
Does this clarify the workflow? Thank you!
I hope that means my solution is correct! 
I was wondering if I could ask a follow-up question? With register_full_backward_hook, if I call a double backward on my network, does the backward hook record the double backward or just a single backward? And if I call two in a row (i.e. a double backward with one loss function, then a single backward with a different loss function), which output is stored?
My forward/backward hooks are defined like so,
def _save_input(self, module, input):
    self.state[module]['a'] = input[0]

def _save_output(self, module, grad_input, grad_output):
    self.state[module]['gl'] = grad_output[0]
So, I'd assume that whenever backward is called, it overwrites the value stored by the previous call?
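As a minimal standalone sketch to test the overwrite part (using a toy nn.Linear rather than my actual network):

import torch
import torch.nn as nn

lin = nn.Linear(3, 2)
stored = {}

def save_output(module, grad_input, grad_output):
    # same pattern as _save_output: each call replaces the stored tensor
    stored['gl'] = grad_output[0]

lin.register_full_backward_hook(save_output)

x = torch.randn(4, 3)
lin(x).sum().backward()
first = stored['gl']
lin(x).mean().backward()
print(stored['gl'] is first)  # False: the second backward overwrote the stored value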
Thank you for all the help! 