Hook function and list of losses with autograd

Hello everyone! =)

import torch

# Counter tracking how many times hook_fn1 fires
value = 0

# Define hook functions
def hook_fn1(grad):
    global value
    print("Hook 1 called")
    value += 1

def hook_fn2(grad):
    print("Hook 2 called")

# Create tensors with requires_grad=True
x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=True)

# Register hook_fn1 on the first tensor
x.register_hook(hook_fn1)

# Register hook_fn2 on the second tensor
y.register_hook(hook_fn2)

# Define two losses; both depend on x, only loss2 depends on y
loss1 = x.sum()
loss2 = x.sum() + y.mean()

# Use torch.autograd to compute gradients for both losses at once
torch.autograd.backward([loss1, loss2])
# Both hooks registered on x and y will be called
print(value)

I ran this code expecting hook_fn1 to be called twice (once per loss), but it is only called once, and I am not sure I understand why.
How can I adapt my code so that the hook runs for each loss? I don't want to use a for loop, because for a big network that would take forever.

Hi! Running this code snippet shouldn't run any of the hooks, actually, since the hooks are called during the backward pass and there are no .backward() calls in the snippet.

If you run the following instead, you should see hook_fn1 called twice (value = 2) while hook_fn2 is called just once.

loss1.backward()
loss2.backward()

Is this the behavior you’re seeing?

Thanks for your answer, but you missed the backward in my code:

torch.autograd.backward([loss1,loss2])

This is where the backward is. The point is that I want to avoid multiple backward calls and use the PyTorch implementation instead.

Welp, LOL, I just realized I could scroll down. Sorry about that! The way the torch.autograd.backward API works is that it sums loss1 and loss2 and then does a single backward on that sum, which is why you observe your current behavior instead of value being 2.
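
To see the equivalence concretely, here is a minimal sketch based on your snippet: summing the losses explicitly triggers hook_fn1 only once, because the gradients flowing into x are accumulated before the tensor hook runs.

import torch

value = 0
def hook_fn1(grad):
    global value
    value += 1

x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=True)
x.register_hook(hook_fn1)

loss1 = x.sum()
loss2 = x.sum() + y.mean()

# Behaves like torch.autograd.backward([loss1, loss2]): a single backward
# pass, so the gradient reaching x is accumulated and hook_fn1 fires once.
(loss1 + loss2).backward()
print(value)  # 1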

I am curious what your use case is. The closest API that comes to mind that would compute multiple forwards/backwards at a time is per-sample gradients (see "Per-sample-gradients" in the PyTorch Tutorials documentation), but that API assumes a very particular structure where all the losses you are computing are pointwise. Here it looks like you have some dependency structure, so it may not be helpful.
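
For reference, here is a minimal sketch of that per-sample-gradient pattern using torch.func; the model, data shapes, and loss are hypothetical, just to show the structure:

import torch
from torch.func import functional_call, grad, vmap

# Hypothetical model; detached copies of the parameters are passed to the
# functional API, which computes gradients with respect to them
model = torch.nn.Linear(3, 1)
params = {k: v.detach() for k, v in model.named_parameters()}

def sample_loss(params, sample, target):
    # add a batch dimension of 1 so the module sees its expected input shape
    pred = functional_call(model, params, (sample.unsqueeze(0),))
    return torch.nn.functional.mse_loss(pred, target.unsqueeze(0))

# grad w.r.t. params, vmapped over the batch dimension of samples/targets
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))

inputs = torch.randn(8, 3)
targets = torch.randn(8, 1)
grads = per_sample_grads(params, inputs, targets)
print(grads["weight"].shape)  # torch.Size([8, 1, 3]): one gradient per sample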

Basically, what backward does is compute the gradient summed over the batch,

grad = \sum_b \frac{\partial \ell_b}{\partial w_{ik}},

but I would like to compute the squared version,

grad_squared = \sum_b \left( \frac{\partial \ell_b}{\partial w_{ik}} \right)^2 .

Sorry, I could not figure out how to format the equations properly.
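
If the per-batch losses here are independent per-sample losses (the structure the per-sample-gradients tutorial assumes), one way to get grad_squared is to compute per-sample gradients and then square and sum them over the batch dimension. A self-contained sketch, reusing the hypothetical model from the snippet above:

import torch
from torch.func import functional_call, grad, vmap

# Hypothetical model and data, just to illustrate the shape of the computation
model = torch.nn.Linear(3, 1)
params = {k: v.detach() for k, v in model.named_parameters()}

def sample_loss(params, sample, target):
    pred = functional_call(model, params, (sample.unsqueeze(0),))
    return torch.nn.functional.mse_loss(pred, target.unsqueeze(0))

# One gradient per batch element for every parameter
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))

inputs = torch.randn(8, 3)
targets = torch.randn(8, 1)
grads = per_sample_grads(params, inputs, targets)

# grad_squared = sum_b (d loss_b / d w_ik)^2:
# square each per-sample gradient, then sum over the batch dimension
grad_squared = {name: (g ** 2).sum(dim=0) for name, g in grads.items()}
print(grad_squared["weight"].shape)  # same shape as model.weight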