What are hooks used for?

Tensors have a method, register_hook(hook), whose documentation says:

Registers a backward hook.

The description says that the hook will be called every time a gradient with respect to the tensor is computed.

My question: what are hooks used for?

Kind regards,
Jens
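
A minimal sketch of that documented behaviour (not part of the original post): the hook below records every gradient it receives. It relies on register_hook returning a handle whose remove() method unregisters the hook; the list calls is just an illustrative name.

```python
import torch

calls = []  # every gradient the hook receives, in order

w = torch.ones(2, requires_grad=True)
# the hook gets the gradient w.r.t. w; append returns None, so the gradient is left unchanged
handle = w.register_hook(lambda grad: calls.append(grad.clone()))

(3 * w).sum().backward()  # hook fires with tensor([3., 3.])
(5 * w).sum().backward()  # hook fires again with tensor([5., 5.])
handle.remove()           # unregister the hook
(7 * w).sum().backward()  # the gradient is still computed, but the hook is no longer called

print(len(calls))  # 2
print(w.grad)      # tensor([15., 15.]): .grad accumulates 3 + 5 + 7
```

Note that the hook sees the gradient of each individual backward pass, while .grad keeps accumulating across passes.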


You could pass a function as the hook to register_hook; it will be called with the gradient every time the gradient is calculated.
This can be useful for debugging, e.g. printing the gradient or some of its statistics. You could of course also manipulate the gradient in a custom way, e.g. normalize it: the hook should not modify its argument in place, but it can return a new tensor, which is then used instead of the original gradient.
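
Both uses can be sketched with two hooks on the same tensor. The names log_stats and normalize are hypothetical, and normalizing to unit L2 norm is just one example of a custom rule:

```python
import torch

def log_stats(grad):
    # debugging only: inspect the gradient; returning None leaves it unchanged
    print(f"grad mean={grad.mean().item():.4f} norm={grad.norm().item():.4f}")

def normalize(grad):
    # manipulation: build a new tensor instead of editing `grad` in place;
    # the returned tensor replaces the original gradient
    return grad / (grad.norm() + 1e-12)

w = torch.randn(4, requires_grad=True)
w.register_hook(log_stats)   # hooks run in registration order
w.register_hook(normalize)

loss = (10 * w).sum()        # raw gradient w.r.t. w is 10 for every entry
loss.backward()              # prints: grad mean=10.0000 norm=20.0000
print(w.grad)                # the normalized gradient: 0.5 in every entry
```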


Is it only used for reporting by the user, or is it also used internally by the back-propagation algorithm?


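The latter: hooks run as part of the backward pass. If a hook returns a modified gradient, that gradient replaces the original one: it is propagated further back through the graph (for intermediate tensors) or accumulated into .grad (for leaf tensors such as parameters), so the optimizer will use these custom gradients to update the parameters.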
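
A small sketch of that effect, assuming a plain torch.optim.SGD optimizer and a hypothetical custom rule that clamps every gradient entry to [-1, 1]:

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
opt = torch.optim.SGD([w], lr=0.1)

# hypothetical custom rule: clamp each gradient entry to [-1, 1]
w.register_hook(lambda grad: grad.clamp(-1.0, 1.0))

loss = (w * torch.tensor([5.0, 0.5])).sum()  # raw gradient w.r.t. w is [5.0, 0.5]
opt.zero_grad()
loss.backward()
print(w.grad)   # tensor([1.0000, 0.5000]): the clamped gradient was stored in .grad
opt.step()      # SGD update w <- w - lr * w.grad uses the clamped values
print(w.data)   # roughly [0.90, -2.05]
```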


I see, thanks.

Are they also used if I call backward() on a tensor?

Yes, here is a small example:

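```python
import torch

x = torch.randn(1, 1)
w = torch.randn(1, 1, requires_grad=True)
w.register_hook(lambda grad: print(grad))     # called with the gradient w.r.t. w
y = torch.randn(1, 1)

out = x * w
loss = (out - y)**2
loss.register_hook(lambda grad: print(grad))  # called with the gradient w.r.t. loss
# loss.mean() is a 0-dim tensor, so the gradient passed to backward() must be 0-dim as well
loss.mean().backward(gradient=torch.tensor(0.1))  # prints the gradient w.r.t. loss, then w.r.t. w
```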
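
To check what the two hooks print, here is a deterministic variant of the example above: fixed values replace torch.randn, and a hypothetical helper make_hook records the gradients instead of printing them. The expected values are worked out by hand in the comments.

```python
import torch

seen = {}   # gradient received by each hook
order = []  # order in which the hooks fire

def make_hook(name):
    # returns a hook that records the gradient it receives (returns None, so nothing is changed)
    def hook(grad):
        seen[name] = grad
        order.append(name)
    return hook

x = torch.tensor([[2.0]])
w = torch.tensor([[3.0]], requires_grad=True)
y = torch.tensor([[1.0]])

w.register_hook(make_hook("w"))
loss = (x * w - y) ** 2
loss.register_hook(make_hook("loss"))
loss.mean().backward(gradient=torch.tensor(0.1))

# by hand: the mean over a single element passes 0.1 through unchanged, so the loss hook sees 0.1;
# the w hook sees 0.1 * 2 * (x*w - y) * x = 0.1 * 2 * 5 * 2 = 2.0
print(order)  # ['loss', 'w']
print(seen["loss"], seen["w"])
```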

Can we pass our own parameters to this hook function? I mean multiple parameters.

Yes. Autograd only ever passes the gradient to the hook, but you can bind more parameters to the lambda, e.g. as default arguments:

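```python
# same script as above, up to (but not including) the backward() call
my_param = "lala"
# the default argument binds the value of my_param when the lambda is defined
loss.register_hook(lambda grad, my_param=my_param: print(my_param, grad))
loss.mean().backward(gradient=torch.tensor(0.1))
```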
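
Default arguments are one option. Another, sketched here with the standard-library functools.partial, is to bind the extra parameters up front so autograd can still call the hook with the gradient alone. The function report, its parameters name and scale, and the list seen are hypothetical names for illustration.

```python
import functools

import torch

seen = []  # (name, scaled gradient) pairs recorded by the hook

def report(grad, name, scale):
    # autograd passes only `grad`; `name` and `scale` were bound in advance by functools.partial
    seen.append((name, grad * scale))
    print(name, grad * scale)

w = torch.ones(2, requires_grad=True)
w.register_hook(functools.partial(report, name="w", scale=10.0))
(3 * w).sum().backward()  # prints: w tensor([30., 30.])
```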

@ptrblck I was wondering whether the gradients passed to registered hooks can have requires_grad=True.
More specifically, I am registering hooks for a recurrent network and want the gradients of the gradients (second derivatives), i.e. the second derivative of the loss w.r.t. each hidden state.
I could not find a way to do this directly.
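
One possible approach, not given in the thread and shown only as a sketch: calling backward(create_graph=True) makes autograd record the backward pass itself, so the gradient handed to a hook is part of the graph and can be differentiated again with torch.autograd.grad. The single tensor h below is an assumed stand-in for one hidden state of the recurrent network, and captured is a hypothetical name.

```python
import torch

captured = []  # gradients w.r.t. the hidden state, filled in by the hook

x = torch.randn(3)
w = torch.randn(3, requires_grad=True)
h = w * x  # hypothetical stand-in for one hidden state of the RNN
h.register_hook(lambda grad: captured.append(grad))
loss = (h ** 2).sum()

# create_graph=True records the backward pass itself, so the gradient handed to
# the hook is differentiable (PyTorch may warn about a reference cycle on w.grad)
loss.backward(create_graph=True)
g_h = captured[0]         # dL/dh = 2 * h
print(g_h.requires_grad)  # True

# differentiate the captured gradient again: this gives the Hessian of the loss w.r.t. h
# times a vector of ones; for loss = sum(h**2) the Hessian is 2*I, so every entry is 2
(d2,) = torch.autograd.grad(g_h.sum(), h)
print(d2)  # tensor([2., 2., 2.])
```

For a full per-element second derivative (rather than a Hessian-vector product) you would need one torch.autograd.grad call per entry or a different vector in place of the ones.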