Hey,

so I’m working on a problem where I have multiple losses and I want to stop backpropagation of one of the losses before the end of the computational graph. In more detail:

I have a tensor x (y_0 in my second post) which has already been passed through a couple of PyTorch functions, so its grad_fn attribute is not None. This tensor x is then passed through a subnetwork y = f(x; theta). y is used in a direct loss L(y) and is also passed on to another subnetwork. Now, I want the direct loss to affect only the parameters theta of f and not any parameters prior to f, i.e. backpropagation should be stopped at x for the direct loss L. However, the loss from the second subnetwork should backpropagate through the entire computational graph, so detaching x is not possible.

I could, of course, clone x, detach the clone and make a second pass through f, but I would like to avoid that for efficiency reasons. Is there a smarter way to do this?
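Roughly, the workaround I would like to avoid looks like this (the module and loss names are just toy placeholders):

```python
import torch
import torch.nn as nn

pre = nn.Linear(4, 4)   # whatever produces x, so x already has a grad_fn
f = nn.Linear(4, 4)     # the subnetwork with parameters theta
g = nn.Linear(4, 1)     # the second subnetwork

x = pre(torch.randn(2, 4))
y = f(x)                            # used by the second subnetwork, full graph
y_cut = f(x.detach())               # extra forward pass through f, graph cut at x
loss_direct = y_cut.pow(2).mean()   # only reaches theta
loss_second = g(y).pow(2).mean()    # reaches pre, f and g
(loss_direct + loss_second).backward()
```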

Thank you for your help!

I’m not sure if I fully understand the network structure, but would explicitly computing the gradients for the subnetwork via e.g. `torch.autograd.grad`, or by providing `inputs` to `backward`, work?

Hey ptrblck,

thanks for your fast response!

Here is an image of the network layout:

f_0, f_1 and f_2 are all NNs with parameters phi, theta and psi, respectively. The second network f_1 has an additional loss L_1, but its gradients should only affect network f_1. Thus, the gradients of L_1 w.r.t. f_0 (i.e. w.r.t. phi) should be 0. If I just called backward naively on L_1, I would get a non-zero contribution to the total gradients of phi. This is what I want to avoid.
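In code, the layout is roughly this (the modules and losses are just toy placeholders):

```python
import torch
import torch.nn as nn

f_0 = nn.Linear(4, 4)   # parameters phi
f_1 = nn.Linear(4, 4)   # parameters theta, has the extra loss L_1
f_2 = nn.Linear(4, 1)   # parameters psi

x = torch.randn(2, 4)
y_0 = f_0(x)
y_1 = f_1(y_0)
y_2 = f_2(y_1)

L_1 = y_1.pow(2).mean()   # should only update theta
L_2 = y_2.pow(2).mean()   # should update phi, theta and psi
```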

If I understand `torch.autograd.grad` correctly, it would do the same as cloning and detaching y_0 and then using the cloned y_0 for the loss L_1, is that correct?

> If I understand `torch.autograd.grad` correctly, it would do the same as cloning and detaching y_0 and then using the cloned y_0 for the loss L_1, is that correct?

Yes, but the difference is that the `.grad` fields won’t be populated. Instead, `torch.autograd.grad()` returns the gradients directly.
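For example, something like this (reusing the names from your layout above):

```python
# Gradients of L_1 w.r.t. theta only; nothing before f_1 is touched
# and no .grad field is populated. retain_graph=True keeps the shared
# graph alive for the later backward of L_2.
grads = torch.autograd.grad(L_1, list(f_1.parameters()), retain_graph=True)

# If you want the usual optimizer step, you can accumulate them yourself:
for p, g in zip(f_1.parameters(), grads):
    p.grad = g if p.grad is None else p.grad + g
```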

You can also pass `inputs=` to `.backward()`; it does the same thing, except that it also populates the `.grad` fields of the given inputs (as ptrblck mentioned).
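That is, roughly:

```python
# Only theta (f_1's parameters) gets .grad accumulated from L_1;
# retain_graph=True keeps the shared graph alive for L_2.backward().
L_1.backward(inputs=list(f_1.parameters()), retain_graph=True)
L_2.backward()   # backpropagates through the whole graph as usual
```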