Is it possible to get the Hessian of the loss with respect to the output of a layer via `register_full_backward_hook`?

Hi All,

I was just wondering if there’s a possible extension to hooks in order to get the Hessian of the loss with respect to the output of a layer. For example, using register_full_backward_hook returns the variable grad_output, which represents the derivative of the loss with respect to the module’s output, which I’ll denote dL_dsi (for the output, s, of the i-th layer of a network). It has the same shape as the output of that layer.

I was wondering if it’s at all possible (or indeed even feasible) to take that grad_output and find the derivative of that w.r.t. another grad_output variable. So, in effect, calculate d2L_dsi_dsj, where si and sj represent the outputs of the i-th and j-th layers respectively.

Any suggestions or help would be greatly appreciated!

Thank you! :smiley:

If you want to compute second-order derivatives, I think you mean to find the derivative of your grad_output wrt another output (not another grad_output), because that other grad_output would not be part of this grad_output’s graph.

Have you tried registering a full backward hook on a module in which your hook function calls autograd.grad(grad_wrt_this_output, (other_output,))?

Hey!

Yes, that’s one way of saying it. But essentially I want to compute the second derivative of the loss with respect to the outputs of the i-th and j-th layers. To me it made more sense to think of it as the derivative of one grad_output with respect to another grad_output (in terms of calling autograd).

I guess a further question would be: how exactly could I get these derivatives efficiently? My current approach is to do this within an Optimizer (as my Optimizer requires the use of register_full_backward_hook).

I store all these backward hooks within a defaultdict of torch.optim.Optimizer (the state attribute, to be precise). Would it be as simple as iterating through that defaultdict in a nested loop and just calling torch.autograd.grad on all pairs? (Like you’ve stated!)

Although, my question is how exactly could these two tensors be connected if no graph is created to represent a derivative of it? (In a similar way to compute second derivatives you need to enable create_graph=True in torch.autograd.grad, for example)

Hi @soulitzer,

I tried what you suggested with applying torch.autograd.grad, however, it throws an error as the grad_output doesn’t have a grad_fn.

I did do this outside of the hook. So,

  def _hessian_hook_check(self):
    all_grad_outputs = self._get_all_grad_outputs() #grabs all grad_outputs

    d1=all_grad_outputs.copy()
    d2=all_grad_outputs.copy()

    for keyi in d1:
      for keyj in d2:
        d2L_dsi_dsj, = torch.autograd.grad(d2[keyi], d2[keyj])
        print("d2L_dsi_dsj: ",d2L_dsi_dsj)

with the following error,

Traceback (most recent call last):
  File "main.py", line 169, in <module>
    optim.step(loss=energy_loss)    
  File "~/.local/lib/python3.8/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "~/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/main.py", line 590, in step
    self._hessian_hook_check()
  File "~/main.py", line 601, in _hessian_hook_check
    d2L_dsi_dsj, = torch.autograd.grad(d2[keyi], d2[keyj])
  File "~/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 234, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Now I assume this is what you mean here? That I should call the calculation of derivative of grad_output for a given module with respect to the other modules within the backward hook itself, rather than calling within a different method? But then how exactly could I compute the gradient of the grad_output of the final layer with respect to the grad_output of 2nd to last layer if it hasn’t been created yet? (Unless it’s equal to 0 by some definition). As I’d be doing this calculation within the hook of the final layer before I’ve called the hook of the 2nd to last layer?

Thank you!

Hmm, could you just do torch.autograd.grad(loss, (out1, out2, ... ), create_graph=True), where the outs are the outputs of each of the layers? I don’t see why you need to do this with hooks.
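To make that suggestion concrete, here is a minimal sketch of getting the first-order gradients with respect to layer outputs without any hooks. The two-layer network, the tanh activation, and the squared loss are assumptions chosen for illustration and are not taken from the thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy layers, only for illustration.
fc1 = nn.Linear(2, 8)
fc2 = nn.Linear(8, 1)

x = torch.randn(16, 2)
out1 = torch.tanh(fc1(x))   # output of the first layer, shape [16, 8]
out2 = fc2(out1)            # output of the second layer, shape [16, 1]
loss = out2.pow(2).mean()

# One call returns dL/d(out) for every layer output that is passed in.
# create_graph=True records the gradient computation itself, so the
# returned gradients can be differentiated a second time.
grads = torch.autograd.grad(loss, (out1, out2), create_graph=True)

for out, g in zip((out1, out2), grads):
    print(tuple(out.shape), tuple(g.shape), g.requires_grad)
```

Each returned gradient has the same shape as the corresponding layer output, which matches what grad_output would contain in a full backward hook.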

Thinking about it more, I don’t necessarily need to do it via hooks. It’s just that I use the hooks to reconstruct the gradients of a given layer for all samples, so I thought it might be easy to do the same for the Hessian.

When doing this, that would give the gradient of the loss with respect to the output of each layer (for all input samples?), right? If I wanted the Hessian, I assume I could do a similar thing to what I stated above and use a nested loop to get the Hessian for different gradients?

first_grads = torch.autograd.grad(loss, (out1, out2, ... ), create_graph=True)

hessian_terms = []

for first_grad1 in first_grads:
  for first_grad2 in first_grads:
    hessian_terms.append( torch.autograd.grad(first_grad1, first_grad2, create_graph=False)[0] )

Is this something that’s feasible?

hmm, wouldn’t it be something like:

for first_grad in first_grads:
  for out in outs:
     hessian_terms.append(...)

If you did:

for first_grad1 in first_grads:
  for first_grad2 in first_grads:

You didn’t use dloss_out1 to compute dloss_out2, so it wouldn’t be part of your graph and you would probably get an error there.
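The loop structure above (first-order gradients on the outside, layer outputs on the inside) can be sketched as follows. This is only an illustration under assumptions: the tiny tanh network and squared loss are made up so that the second derivatives are non-trivial, and hessian_block is a hypothetical helper name. It differentiates one gradient element at a time, so it is only practical for very small tensors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny network; sizes are kept small on purpose.
fc1 = nn.Linear(2, 3)
fc2 = nn.Linear(3, 2)

x = torch.randn(4, 2)
out1 = torch.tanh(fc1(x))      # shape [4, 3]
out2 = torch.tanh(fc2(out1))   # shape [4, 2]
loss = out2.pow(2).sum()

outs = (out1, out2)
first_grads = torch.autograd.grad(loss, outs, create_graph=True)

def hessian_block(first_grad, out):
    """Differentiate every element of first_grad with respect to out."""
    rows = []
    for g in first_grad.reshape(-1):
        (row,) = torch.autograd.grad(g, out, retain_graph=True, allow_unused=True)
        # None means autograd found no path from g to out: treat it as zero.
        rows.append(torch.zeros_like(out) if row is None else row)
    return torch.stack(rows).reshape(first_grad.shape + out.shape)

# blocks[(i, j)] holds d(dL/d out_i) / d(out_j) as autograd sees the graph.
blocks = {
    (i, j): hessian_block(gi, oj)
    for i, gi in enumerate(first_grads)
    for j, oj in enumerate(outs)
}

for key, block in blocks.items():
    print(key, tuple(block.shape))
```

Note that the inner call differentiates with respect to an output, not with respect to another gradient, which is the point being made above.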

Hi @soulitzer,

I’m a little bit confused about how I can get these values. I wrote the following script yesterday to test what you stated above, but my attempt throws an error and I’m not 100% sure why it fails.

import torch
import torch.nn as nn

class network(nn.Module):

  def __init__(self):
    super(network, self).__init__()
    
    self.fc1 = nn.Linear(2, 32, bias=True)
    self.fc2 = nn.Linear(32,32, bias=True)
    self.fc3 = nn.Linear(32, 2, bias=True)
    
  def forward(self, x):
    self.x1 = self.fc1(x)
    self.x2 = self.fc2(self.x1)
    self.x3 = self.fc3(self.x2)
    return self.x3
    
net = network()

x = torch.randn(4096, 2)

y = net(x)

loss = y.sum(dim=-1).mean()

xis = [net.x1, net.x2, net.x3]

for xi in xis:
  grad, = torch.autograd.grad(loss, xi, torch.ones_like(loss), retain_graph=True, create_graph=True)
  for xj in xis: 
    gradgrad, = torch.autograd.grad(grad, xj, torch.ones_like(grad), allow_unused=True)

The error message is,

Traceback (most recent call last):
  File "print_hessian.py", line 32, in <module>
    gradgrad, = torch.autograd.grad(grad, xj, torch.ones_like(grad), allow_unused=True)
  File "~/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 234, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Now I’d assume this is the case because each of the xis don’t have a grad_fn? Is that correct?

It might be easier to tell you exactly why I wanted the grad_output terms differentiated. I’m trying to get the Hessian of my loss with respect to all parameters for all input samples. So, in effect, I’d have a Tensor of size [B,N,N], where B is the batch size and N is the number of parameters within the network. (Although this could be reduced to doing it for given pairs of layers, so N would be Ni + Nj for the i-th and j-th layers.)

Thank you for your help!

I think the problem is that the grad of loss wrt the last layer output x3 is constant wrt any of the inputs that require grad, because all you do is a sum() followed by a mean().
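A small sketch of that point, reusing the layer sizes from the script above but with a smaller batch. The second loss, loss_sq, is an assumption added purely for contrast and is not part of the original script.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

fc1, fc2, fc3 = nn.Linear(2, 32), nn.Linear(32, 32), nn.Linear(32, 2)

x = torch.randn(8, 2)
x1 = fc1(x)
x2 = fc2(x1)
x3 = fc3(x2)

# Loss that is linear in x3: its gradient wrt x3 is the constant 1/8,
# so there is nothing left to differentiate a second time.
loss = x3.sum(dim=-1).mean()
(g3,) = torch.autograd.grad(loss, x3, create_graph=True)
print("linear loss, grad requires_grad:", g3.requires_grad)

# Loss that is quadratic in x3 (assumed here for contrast): its gradient
# 2 * x3 / 8 still depends on x3, so a second derivative exists.
loss_sq = x3.pow(2).sum(dim=-1).mean()
(g3_sq,) = torch.autograd.grad(loss_sq, x3, create_graph=True)
print("quadratic loss, grad requires_grad:", g3_sq.requires_grad)

(gg,) = torch.autograd.grad(g3_sq.sum(), x3)
print("second derivative shape:", tuple(gg.shape))
```

Calling torch.autograd.grad on g3 from the linear loss is what raises the "does not require grad and does not have a grad_fn" error, since g3 is not connected to any graph.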

If you are computing the Hessian of the loss w.r.t. the parameters, wouldn’t the size eventually be [B, X, N, N], where X is the numel of the parameter, assuming all parameters have the same numel? Also, if two parameters a, b had different sizes, e.g., weight and bias, d2L_dadb would have the size of a but d2L_dbda would have the size of b, so that may confuse things.

From the above, though, we seem to be computing the Hessian of the loss w.r.t. each of the layer outputs, which is a different quantity.

Yes, that’s true, especially if different layers have different sizes. I was reading through the following, and it seems like it might not even be possible to calculate the Hessian w.r.t. all the parameters for all the samples? torch.autograd.functional.* for models · Issue #40480 · pytorch/pytorch · GitHub
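For readers on newer releases: the sketch below shows one possible way to get a per-sample Hessian with respect to the parameters using torch.func. It assumes PyTorch 2.0 or later, which is newer than the version shown in the tracebacks above. The small Sequential model and the sample_loss function are hypothetical and only serve as an illustration, not as the approach discussed in the linked issue.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, hessian, vmap

torch.manual_seed(0)

# Hypothetical small model; a full Hessian grows quadratically with the
# number of parameters, so this only scales to small networks.
net = nn.Sequential(nn.Linear(2, 3), nn.Tanh(), nn.Linear(3, 1))
params = {name: p.detach() for name, p in net.named_parameters()}

def sample_loss(p, x_single):
    # Run the model on ONE sample with the parameters passed explicitly.
    y = functional_call(net, p, (x_single.unsqueeze(0),))
    return y.pow(2).sum()

x = torch.randn(5, 2)  # batch of B = 5 samples

# hessian differentiates twice wrt the first argument (the parameter dict);
# vmap maps over the batch dimension of x while sharing the parameters.
per_sample_hess = vmap(hessian(sample_loss), in_dims=(None, 0))(params, x)

# The result is a nested dict: one block per pair of parameters, with
# shape [B, *shape_of_first, *shape_of_second].
for a, row in per_sample_hess.items():
    for b, block in row.items():
        print(a, b, tuple(block.shape))
```

Because the blocks are keyed by parameter pairs, a weight/bias block and its bias/weight counterpart have different shapes, which is the size mismatch mentioned earlier in the thread. Flattening and concatenating the blocks would give the [B, N, N] layout.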