Backpropagating through autograd.grad

I’m facing an issue when trying to use the Jacobian of a computation as part of the loss function.
I created this minimal example with two networks.

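import torch

net1 = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Tanh())
net2 = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Tanh())

x = torch.randn(4, 3)     # example input batch (assumed; missing from the snippet)
hidden = net1(x)
output = net2(hidden)

# Derivative of output w.r.t. hidden; create_graph=True makes jac itself differentiable
jac = torch.autograd.grad(output, hidden, grad_outputs=torch.ones_like(output), create_graph=True, retain_graph=True)[0]

loss = jac.pow(2).mean()  # example Jacobian loss (assumed; missing from the snippet)
loss.backward()

print("net1 grad", net1[0].weight.grad.norm())  # this should be non-zero
print("net2 grad", net2[0].weight.grad.norm())  # but how to make this zero?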

The Jacobian of output with respect to hidden is used in the loss function.
After calling loss.backward(), the gradients of both net1 and net2 are populated, as expected.
However, I would like to optimize only net1 and keep the weights of net2 frozen.
From my understanding, the gradient flows backwards through torch.autograd.grad, then through output, and finally into net2.
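A small check of why the gradient reaches net2: for a single Linear(2, 1) layer followed by Tanh, the derivative of output with respect to hidden is (1 - output²) times the layer's weight matrix, so net2's weight appears directly inside jac. The sketch below verifies this numerically; the random input x and the seed are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
net1 = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Tanh())
net2 = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Tanh())

x = torch.randn(4, 3)  # hypothetical input batch
hidden = net1(x)
output = net2(hidden)

jac = torch.autograd.grad(output, hidden,
                          grad_outputs=torch.ones_like(output),
                          create_graph=True)[0]

# Closed form for Linear + Tanh: d output / d hidden = (1 - output^2) * W2
W2 = net2[0].weight                 # shape (1, 2)
analytic = (1 - output ** 2) * W2   # (4, 1) * (1, 2) broadcasts to (4, 2)

print(torch.allclose(jac, analytic))  # True
print(jac.grad_fn is not None)        # True: jac is part of the graph
```

Because W2 is a factor of jac, any loss built from jac has a direct path to net2's weight, independent of whatever gradient arrives at output.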
I tried registering a hook on output:

output.register_hook(lambda grad: torch.zeros_like(grad))

which, from my understanding, should set the gradient flowing back from output to zero, so net2 should essentially be frozen. But this is not the case.
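A minimal sketch of what the hook actually does, assuming the same networks, a random input and the made-up loss jac.pow(2).mean(). The hook is registered after jac is computed so that jac itself is unchanged. Zeroing the gradient at output cuts the only path back to hidden (and therefore to net1), while net2's weight still receives gradient through the copy of W2 that sits inside the Jacobian's graph.

```python
import torch

torch.manual_seed(0)
net1 = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Tanh())
net2 = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Tanh())

x = torch.randn(4, 3)  # hypothetical input batch
hidden = net1(x)
output = net2(hidden)
jac = torch.autograd.grad(output, hidden,
                          grad_outputs=torch.ones_like(output),
                          create_graph=True)[0]

# Zero whatever gradient arrives at `output` during loss.backward()
output.register_hook(lambda grad: torch.zeros_like(grad))

loss = jac.pow(2).mean()  # hypothetical Jacobian loss
loss.backward()


def grad_norm(p):
    # Treat a missing gradient as zero
    return 0.0 if p.grad is None else p.grad.norm().item()


print("net1 grad", grad_norm(net1[0].weight))  # 0.0: net1 is reached only via output
print("net2 grad", grad_norm(net2[0].weight))  # non-zero: W2 enters jac directly
```

So the hook has the opposite of the intended effect: it blocks the gradient into net1 and leaves the direct path into net2 open.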

Do you have any suggestions on how to solve this?
Thank you very much for your help!

Just a suggestion, and I’m not sure if I’m right: if you want to optimize only net1, you can pass only net1’s parameters to the optimizer, something like optimizer = optimizer_name(net1.parameters()). Also, if you don’t want the gradients of net2 to be populated, you can set requires_grad to False for each of its parameters:

for param in net2.parameters():
    param.requires_grad = False  # autograd no longer populates .grad for net2
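Put together, the suggestion could look like the sketch below. It assumes plain SGD with lr=0.1, a random input and the made-up loss jac.pow(2).mean(). With net2 frozen, jac is still differentiable with respect to hidden, so net1 receives gradient while net2's .grad stays None.

```python
import torch

torch.manual_seed(0)
net1 = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Tanh())
net2 = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Tanh())

# Freeze net2 so autograd never populates its .grad
for param in net2.parameters():
    param.requires_grad = False

# Only net1's parameters are handed to the optimizer
optimizer = torch.optim.SGD(net1.parameters(), lr=0.1)

x = torch.randn(4, 3)  # hypothetical input batch
hidden = net1(x)
output = net2(hidden)
jac = torch.autograd.grad(output, hidden,
                          grad_outputs=torch.ones_like(output),
                          create_graph=True)[0]
loss = jac.pow(2).mean()  # hypothetical Jacobian loss

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(net1[0].weight.grad is not None)  # True
print(net2[0].weight.grad is None)      # True
```

The catch, raised in the next reply, is that requires_grad=False freezes net2 for every loss, not just this one.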

Thanks Abhibha for the suggestion.

Indeed that would solve this particular example.
The problem is that I still need net2 to be optimizable. In this minimal example I omitted some parts for the sake of brevity, but I have other losses in the network that should behave normally and backpropagate to both net1 and net2.
I just need this one particular loss not to propagate into net2: the “hidden” tensor should be modified, but the non-linear function that net2 computes should not be aware of the Jacobian loss.
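One way to get this behaviour, offered as a sketch rather than as the thread's answer: evaluate net2 a second time for the Jacobian term using detached copies of its parameters through torch.func.functional_call (PyTorch 2.0 or later assumed). The Jacobian loss can then still move hidden, and thus net1, while net2's parameters act as constants in that branch. The loss output.pow(2).mean() is a placeholder for the other losses that should train both networks.

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)
net1 = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Tanh())
net2 = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Tanh())

x = torch.randn(4, 3)  # hypothetical input batch
hidden = net1(x)

# Normal branch: gradients flow into both networks
output = net2(hidden)
other_loss = output.pow(2).mean()  # placeholder for the other losses

# Jacobian branch: same function, but net2's parameters are detached,
# so this loss can change `hidden` (and net1) but not net2's weights
frozen_params = {name: p.detach() for name, p in net2.named_parameters()}
output_frozen = functional_call(net2, frozen_params, (hidden,))
jac = torch.autograd.grad(output_frozen, hidden,
                          grad_outputs=torch.ones_like(output_frozen),
                          create_graph=True)[0]
jac_loss = jac.pow(2).mean()  # hypothetical Jacobian loss

(other_loss + jac_loss).backward()
```

The design trades one extra forward pass through net2 for a single backward() call with no hooks or requires_grad toggling.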

I guess my confusion is about whether to call backward() twice: once with the normal loss and requires_grad=True, and once with the Jacobian loss after switching to requires_grad=False, as you suggested.
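A variant of the two-backward idea that avoids toggling requires_grad, sketched under assumptions: Tensor.backward accepts an inputs argument (PyTorch 1.8 or later) that limits which tensors get gradients accumulated into .grad. The losses are placeholders, and retain_graph=True is needed on the first pass because both losses share the forward graph.

```python
import torch

torch.manual_seed(0)
net1 = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Tanh())
net2 = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Tanh())
optimizer = torch.optim.SGD(list(net1.parameters()) + list(net2.parameters()), lr=0.1)

x = torch.randn(4, 3)  # hypothetical input batch
hidden = net1(x)
output = net2(hidden)
jac = torch.autograd.grad(output, hidden,
                          grad_outputs=torch.ones_like(output),
                          create_graph=True)[0]

other_loss = output.pow(2).mean()  # placeholder for losses that train both nets
jac_loss = jac.pow(2).mean()       # hypothetical Jacobian loss

optimizer.zero_grad()
# First pass: normal loss into both networks; keep the shared graph alive
other_loss.backward(retain_graph=True)
net2_grad_before = net2[0].weight.grad.clone()

# Second pass: accumulate the Jacobian loss only into net1's parameters
jac_loss.backward(inputs=list(net1.parameters()))

print(torch.equal(net2[0].weight.grad, net2_grad_before))  # True
optimizer.step()
```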
The other alternative is using hooks, but I’m not sure whether I’m missing something there, as using

output.register_hook(lambda grad: torch.zeros_like(grad))

still propagates gradients into net2.

As far as I understand, you can reset the gradients of net2’s parameters like this:

for param in net2.parameters():
    param.grad = None  # discards whatever gradient net2 has accumulated so far
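Clearing net2's gradients only isolates the Jacobian loss if the order is right. One possible sketch, assuming placeholder losses and SGD: backpropagate the Jacobian loss first, clear net2's gradients, then backpropagate the remaining losses. The reference gradient expected is computed only to check that net2 ends up with the gradient of other_loss alone.

```python
import torch

torch.manual_seed(0)
net1 = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Tanh())
net2 = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Tanh())
optimizer = torch.optim.SGD(list(net1.parameters()) + list(net2.parameters()), lr=0.1)

x = torch.randn(4, 3)  # hypothetical input batch
hidden = net1(x)
output = net2(hidden)
jac = torch.autograd.grad(output, hidden,
                          grad_outputs=torch.ones_like(output),
                          create_graph=True)[0]

other_loss = output.pow(2).mean()  # placeholder for losses that train both nets
jac_loss = jac.pow(2).mean()       # hypothetical Jacobian loss

# For checking only: net2's gradient from other_loss by itself
expected = torch.autograd.grad(other_loss, net2[0].weight, retain_graph=True)[0]

optimizer.zero_grad()
jac_loss.backward(retain_graph=True)  # fills grads of net1 and net2
for param in net2.parameters():
    param.grad = None                 # drop the Jacobian-loss part for net2
other_loss.backward()                 # now accumulate the normal loss into both
optimizer.step()

print(torch.allclose(net2[0].weight.grad, expected))  # True
```

Done in the opposite order, param.grad = None would also wipe out net2's gradient from the other losses.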