Calculate second derivative related to preactivations

Let s denote the preactivation units (after the Linear layer, before the ReLU). I can get the gradients of the preactivations, dE/ds, through register_backward_hook. But how can I get the gradients of f(dE/ds) w.r.t. the weights, where f is some other function (e.g. a norm)? It seems that grad_out in the hook does not require grad and has no grad_fn.

This is similar to computing a second derivative, except that I'm considering f(dE/ds) instead of f(dE/dW). I know this is somewhat unusual, but I'm experimenting with some new ideas related to the Hessian matrix. Just wondering whether this is possible…
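
To make the goal concrete, here is a minimal sketch of the quantity I'm after, written with `torch.autograd.grad` instead of hooks; the tiny Linear → ReLU → Linear network, the shapes, and the choice of norm are only illustrative assumptions:

```
import torch
import torch.nn as nn

# hypothetical tiny network: Linear -> ReLU -> Linear -> scalar loss
lin1 = nn.Linear(4, 8)
lin2 = nn.Linear(8, 1)

x = torch.randn(16, 4)
s = lin1(x)                                  # preactivations (non-leaf)
loss = lin2(torch.relu(s)).pow(2).mean()

# first order: dE/ds, kept differentiable via create_graph=True
grad_s, = torch.autograd.grad(loss, s, create_graph=True)

# f(dE/ds), e.g. its norm, differentiated w.r.t. the first layer's weights
f = grad_s.norm()
grad_w, = torch.autograd.grad(f, lin1.weight)
print(grad_w.shape)                          # same shape as lin1.weight
```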

Thanks

Hi,

First, I would advise you not to use hooks on nn.Modules but to add them to Tensors instead: `s.register_hook(your_fn)`. In particular, as you can see in the docs, register_backward_hook is discouraged.

Now, in your hook, the gradient will require gradients only if you backprop with `create_graph=True`.
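
For example, a toy sketch (not your code) where the gradient of `s` genuinely depends on `s`, so the hook sees a differentiable gradient once `create_graph=True` is used:

```
import torch

x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(4, 5, requires_grad=True)

s = torch.mm(x, w)
# the hook receives dE/ds; with create_graph=True it carries a grad_fn
s.register_hook(lambda grad: print(grad.requires_grad))

y = (s * s).sum()               # dE/ds = 2*s, which depends on s
y.backward(create_graph=True)   # hook prints True
```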

I tried `s.register_hook(fn)` with `create_graph=True`. The problem is that both the input and output of the hook function `fn` still have `requires_grad = False` and no `grad_fn`, so I can't compute gradients of them again… (If I register_hook on the weights, their gradients do have `requires_grad = True` and a `grad_fn`.)

I know that we can use `torch.autograd.grad` to calculate higher-order gradients for `w`, but the preactivation `s` is a non-leaf variable, and `grad(loss, s)` has neither `requires_grad` nor a `grad_fn`.
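
Concretely, this is roughly what I observe (a minimal sketch with assumed shapes and a simple sum loss):

```
import torch

x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(4, 5, requires_grad=True)

s = torch.mm(x, w)              # non-leaf
loss = s.sum()

grad_s, = torch.autograd.grad(loss, s, create_graph=True)
print(grad_s.requires_grad)     # False: dE/ds is a constant (all ones) here
print(grad_s.grad_fn)           # None
```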

It seems that we can only calculate higher order gradients for leaf variables?

Intermediate results need to require gradients for you to be able to compute gradients with respect to them. If they don't during the forward pass, you need to enable it with `.requires_grad_()` before continuing the forward pass.
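
A minimal sketch of that suggestion, assuming a forward pass whose inputs do not require grad:

```
import torch

x = torch.randn(3, 4)           # inputs that do not require grad
w = torch.randn(4, 5)

s = torch.mm(x, w)              # s does not require grad here
s.requires_grad_()              # tell autograd to start recording at s

y = (s * s).sum()
y.backward(create_graph=True)
print(s.grad.requires_grad)     # True: the accumulated grad is differentiable
```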

In the following toy example, I register hooks for each variable, but only the hooked grads for `x` and `w` print `True`; the one for `s` prints `False`.

```
import torch

def hk(grad):
    # print whether the gradient seen by the hook is itself differentiable
    print(grad.requires_grad)

x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(4, 5, requires_grad=True)

x.register_hook(hk)
w.register_hook(hk)

s = torch.mm(x, w)
s.register_hook(hk)

y = s.sum()
y.backward(create_graph=True)
```

I further set `s.retain_grad()`, but only `x.grad` and `w.grad` have a `grad_fn`; `s.grad` doesn't.

It seems that I can only read the grad values of intermediate variables rather than process them further.

```
import torch

def get_hook(msg):
    def hk(grad):
        # print which tensor this hook belongs to and whether its grad is differentiable
        print(msg, grad.requires_grad)
    return hk

x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(4, 5, requires_grad=True)

x.register_hook(get_hook("x"))
w.register_hook(get_hook("w"))

s = torch.mm(x, w)
s.register_hook(get_hook("s"))

y = s.sum()
# use retain_graph=True to be able to call backward multiple times
y.backward(retain_graph=True)
print("All 3 hooks are called and all print False")
```