What are hooks used for?

Tensors have a function: register_hook.
register_hook(hook)

Registers a backward hook.

The description says that every time a gradient with respect to the tensor is computed, the hook will be called.

My question: what are hooks used for?

Kind regards,
Jens

You can pass a function (the hook) to register_hook; it will be called with the gradient as its argument every time the gradient with respect to that tensor is computed.
This is useful for debugging, e.g. printing the gradient or its statistics, but you can also manipulate the gradient in a custom way, e.g. by normalizing it: if the hook returns a tensor, that tensor replaces the original gradient.
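
A minimal sketch of both uses (the hook names `log_grad_stats` and `normalize_grad` are hypothetical, and scaling to unit L2 norm is just one possible normalization): the first hook only inspects the gradient, while the second returns a rescaled tensor that replaces it. Hooks on the same tensor run in the order they were registered, so the logging hook sees the raw gradient.

```python
import torch

def log_grad_stats(grad):
    # Hypothetical debugging hook: only inspects the gradient.
    # Returning None leaves the gradient unchanged.
    print(f"mean={grad.mean().item():.4f} norm={grad.norm().item():.4f}")

def normalize_grad(grad):
    # Hypothetical manipulating hook: the returned tensor replaces the gradient.
    # Unit L2 norm is only an illustrative choice of normalization.
    return grad / (grad.norm() + 1e-12)

w = torch.tensor([3.0, 4.0], requires_grad=True)
w.register_hook(log_grad_stats)  # registered first, so it sees the raw gradient
w.register_hook(normalize_grad)

loss = (w ** 2).sum()  # d(loss)/dw = 2 * w = [6., 8.]
loss.backward()
print(w.grad)  # tensor([0.6000, 0.8000])
```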

Is it only used for reporting by the user, or is it also used internally by the back-propagation algorithm?

It is used by back-propagation as well. If you manipulate the gradients in a hook, the modified gradients are the ones stored in .grad, and the optimizer will use these custom gradients to update the parameters.
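
A small sketch of that effect, assuming plain SGD and an illustrative clamp to [-0.5, 0.5] as the custom manipulation (this clipping rule is only an example, not something the thread prescribes): the clipped value is what lands in `w.grad`, and `opt.step()` uses it.

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0, 1.0]))
# Illustrative manipulation: clamp each gradient element to [-0.5, 0.5]
w.register_hook(lambda grad: grad.clamp(-0.5, 0.5))

opt = torch.optim.SGD([w], lr=1.0)
loss = (3 * w).sum()  # raw gradient is [3., 3.]
opt.zero_grad()
loss.backward()
print(w.grad)  # tensor([0.5000, 0.5000]): the clamped gradient is stored in .grad

opt.step()  # SGD update: 1.0 - lr * 0.5 = 0.5 for each element
print(w.detach())  # tensor([0.5000, 0.5000])
```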

I see, thanks.

Are they also used if I call backward() on a tensor?

Yes, here is a small example:

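import torch

x = torch.randn(1, 1)
w = torch.randn(1, 1, requires_grad=True)
w.register_hook(lambda grad: print("grad w.r.t. w:", grad))
y = torch.randn(1, 1)

out = x * w
loss = (out - y)**2
loss.register_hook(lambda grad: print("grad w.r.t. loss:", grad))
# loss.mean() is a 0-dim tensor, so the gradient passed to backward() must be 0-dim too
loss.mean().backward(gradient=torch.tensor(0.1))  # prints the gradient w.r.t. loss, then w.r.t. w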
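
Hooks also fire for intermediate (non-leaf) tensors such as `out`, whose `.grad` is not kept by default, so a common variation is to store the gradients instead of printing them. The sketch below uses a hypothetical `save_grad` helper and fixed inputs so the values can be checked by hand.

```python
import torch

grads = {}

def save_grad(name):
    # Hypothetical helper: returns a hook that stores the incoming gradient under `name`
    def hook(grad):
        grads[name] = grad.detach().clone()
    return hook

x = torch.tensor([[2.0]])
w = torch.tensor([[3.0]], requires_grad=True)
y = torch.tensor([[1.0]])

out = x * w  # non-leaf tensor: its .grad is not retained by default
loss = (out - y) ** 2
out.register_hook(save_grad("out"))
w.register_hook(save_grad("w"))

loss.mean().backward()
print(grads["out"])  # d(loss)/d(out) = 2 * (out - y) = tensor([[10.]])
print(grads["w"])    # d(loss)/d(w) = 2 * (out - y) * x = tensor([[20.]])
```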

Can we pass parameters of our own to this “hook” function? I mean multiple parameters.

Yes, you can bind extra parameters as default arguments of the lambda:

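# same setup as above (x, w, y, out, loss), run before the first backward() call,
# since backpropagating through the same graph twice raises an error by default
my_param = "lala"
loss.register_hook(lambda grad, my_param=my_param: print(my_param, grad))
loss.mean().backward(gradient=torch.tensor(0.1))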
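
If the lambda gets unwieldy, `functools.partial` is an alternative way to bind the extra arguments in advance, since autograd only ever passes the gradient to the hook. A sketch, with `tagged_hook` and the `seen` list as hypothetical names used for illustration:

```python
import functools
import torch

seen = []

def tagged_hook(tag, scale, grad):
    # Hypothetical hook: `tag` and `scale` are bound in advance, autograd supplies `grad`
    seen.append((tag, grad * scale))
    print(tag, grad * scale)
    # Returning None leaves the gradient itself unchanged

w = torch.tensor([1.0, 2.0], requires_grad=True)
# partial() binds the first two arguments, leaving a one-argument callable for autograd
w.register_hook(functools.partial(tagged_hook, "w grad x10:", 10.0))

(w ** 2).sum().backward()  # gradient = 2 * w = [2., 4.], so the hook prints [20., 40.]
```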