Let s denote the preactivation units (after Linear, before ReLU). I can get the gradients of the preactivations, dE/ds, through register_backward_hook. But how can I get the gradients of f(dE/ds) w.r.t. the weights, where f is another function (e.g. a norm)? It seems that grad_out in the hook does not require grad and does not have a grad_fn.

This is similar to calculating a second derivative, except that I’m considering f(dE/ds) instead of f(dE/dW). I know this is somewhat unusual, but I’m trying some new approaches related to the Hessian matrix. Just wondering whether this is possible…

First, I would advise you not to use hooks on nn.Modules but to add them to Tensors instead: s.register_hook(your_fn). In particular, as you can see in the docs, register_backward_hook is discouraged.

Now, in your hook, the gradient will require gradients only if you backprop with create_graph=True.

I tried s.register_hook(fn) with create_graph=True. The problem is that both the input and the output of the hook function fn still have requires_grad = False and no grad_fn, so I can’t calculate the gradient again… (If I register_hook for the weights, they do have requires_grad = True and a grad_fn.)

I know that we can use torch.autograd.grad to calculate higher-order gradients for w, but the preactivation s is a non-leaf variable. Then grad(loss, s) has neither requires_grad nor a grad_fn.

It seems that we can only calculate higher-order gradients for leaf variables?

Intermediate results need to require gradients for you to be able to compute gradients with respect to them. If they don’t during the forward pass, you need to enable it with .requires_grad_() before continuing the forward pass.
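
A minimal sketch of that case, assuming frozen weights so that nothing upstream of s requires gradients (the tensor shapes and the squared-ReLU loss are illustrative choices, not from the original post):

```python
import torch

torch.manual_seed(0)

x = torch.randn(16, 5)
w = torch.randn(5, 3)      # frozen weights: requires_grad is False

s = x @ w                  # s does not require gradients at this point
s.requires_grad_()         # enable it before continuing the forward pass

# Nonlinear in s, so dE/ds itself depends on s
loss = torch.relu(s).pow(2).sum()

# create_graph=True keeps the graph of the backward pass
(grad_s,) = torch.autograd.grad(loss, s, create_graph=True)

# f(dE/ds) with f = norm, differentiated once more with respect to s
(second,) = torch.autograd.grad(grad_s.norm(), s)
```

Here grad_s has a grad_fn because the loss is nonlinear in s; for a loss that is linear in s, dE/ds would be a constant and would not require gradients.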

import torch

def get_hook(msg):
    # Build a hook that reports whether the incoming gradient requires grad
    def hk(x):
        print("Does {}'s grad require grad?".format(msg))
        print(x.requires_grad)
    return hk

x = torch.randn(16, 5, requires_grad=True)
x.register_hook(get_hook("x"))
w = torch.randn(5, 3, requires_grad=True)
w.register_hook(get_hook("w"))
s = torch.mm(x,w)
s.register_hook(get_hook("s"))
y = s.sum()
# use retain_graph=True to be able to call backward multiple times
y.backward(retain_graph=True)
print("All 3 hooks are called and all print False")
# Asking for graph to be created and set the grad output to requires_grad=True
go = torch.ones((), requires_grad=True)
y.backward(go, create_graph=True, retain_graph=True)
print("All 3 hooks are called and all print True")
# Asking for graph to be created but the grad output has requires_grad=False
y.backward(create_graph=True, retain_graph=True)
print("All 3 hooks are called and s prints False but x,w print True")
# This is because the default grad output does not require gradients, so the gradient of s doesn't either
# But in the mm, x is used to compute the gradient of w (and w to compute the gradient of x), so those gradients require gradients.
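
For the original question, the same idea can be written with torch.autograd.grad instead of hooks. This is a minimal sketch, assuming a small two-layer network with an MSE loss; the names w1, w2, target and penalty are illustrative and not from the thread:

```python
import torch

torch.manual_seed(0)

x = torch.randn(16, 5)
target = torch.randn(16, 1)
w1 = torch.randn(5, 4, requires_grad=True)
w2 = torch.randn(4, 1, requires_grad=True)

s = x @ w1                      # preactivation: non-leaf, requires grad via w1
h = torch.relu(s)
out = h @ w2
loss = ((out - target) ** 2).mean()

# dE/ds, keeping the graph of the backward pass so it can be differentiated
(grad_s,) = torch.autograd.grad(loss, s, create_graph=True)
print(grad_s.requires_grad, grad_s.grad_fn is not None)

# f(dE/ds) with f = norm, then its gradient with respect to the weights
penalty = grad_s.norm()
grad_w1, grad_w2 = torch.autograd.grad(penalty, (w1, w2))
```

With a purely linear objective such as s.sum(), dE/ds is a constant, which is why the hook on s in the example above prints False unless the grad output itself requires gradients.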