Limit backpropagation depth for multiple backward passes

Hey,

so I’m working on a problem where I have multiple losses and I want to stop backpropagation of one of the losses before the end of the computational graph. In more detail:

I have a tensor x (y_0 in my second post) which has already been passed through a couple of PyTorch functions, so its grad_fn attribute is not None. This tensor x is then passed through a subnetwork y = f(x; theta). y is used in a direct loss L(y) and is also passed on to another subnetwork. Now, I want the direct loss to affect only the parameters theta of f and not any parameters prior to f, i.e. backpropagation should stop at x for the direct loss L. However, the loss from the second subnetwork should backpropagate through the entire computational graph, so detaching x is not an option.

I could, of course, clone x, detach the clone, and make a second pass through f, but I would like to avoid that for efficiency reasons. Is there a smarter way to do this?
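
For concreteness, here is a minimal sketch of the workaround I would like to avoid (the module names and losses below are just placeholders):

```python
import torch
import torch.nn as nn

f_prev = nn.Linear(8, 8)          # layers producing x (must not see the direct loss)
f      = nn.Linear(8, 8)          # the subnetwork f(.; theta)
g      = nn.Linear(8, 8)          # the second subnetwork

x = f_prev(torch.randn(4, 8))     # x already carries a grad_fn

y = f(x)                          # path feeding the second subnetwork
loss_second = g(y).pow(2).mean()  # should backprop through the whole graph

y_detached = f(x.clone().detach())        # second pass through f, graph cut at x
loss_direct = y_detached.pow(2).mean()    # only reaches theta

(loss_second + loss_direct).backward()
```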

Thank you for your help!

I’m not sure I fully understand the network structure, but would explicitly computing the gradients for the subnetwork, e.g. via torch.autograd.grad or by providing the inputs argument to backward, work?
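
Something along these lines, for example (a minimal sketch with made-up modules; `f_prev` stands in for the layers producing x, `f` for the subnetwork whose parameters the direct loss should update):

```python
import torch
import torch.nn as nn

f_prev, f = nn.Linear(8, 8), nn.Linear(8, 8)   # made-up modules: layers before x, and f
x = f_prev(torch.randn(4, 8))                  # x already carries a grad_fn
loss_direct = f(x).pow(2).mean()               # placeholder for L(y)

# Option A: the gradients are returned as a tuple; no .grad field is touched
grads = torch.autograd.grad(loss_direct, list(f.parameters()), retain_graph=True)

# Option B: the gradients are accumulated into .grad, but only for the listed tensors
loss_direct.backward(inputs=list(f.parameters()))

print(f_prev.weight.grad)                      # None: the earlier parameters stay untouched
```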

Hey ptrblck,

thanks for your fast response!
Here is the network layout: input → f_0(·; phi) → y_0 → f_1(·; theta) → y_1 → f_2(·; psi) → y_2, with the additional loss L_1 computed from y_1.

f_0, f_1 and f_2 are all NNs with parameters phi, theta and psi. The second network f_1 has an additional loss L_1, but its gradients should only affect f_1. Thus, the gradients of L_1 w.r.t. phi should be 0. If I just called backward naively on L_1, I would get a non-zero contribution to the total gradients of phi, which is what I want to avoid.
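
A minimal sketch of what I mean (placeholder modules standing in for f_0 and f_1 with parameters phi and theta):

```python
import torch
import torch.nn as nn

f0, f1 = nn.Linear(8, 8), nn.Linear(8, 8)   # placeholders for f_0 (phi) and f_1 (theta)

y0 = f0(torch.randn(4, 8))
y1 = f1(y0)
L1 = y1.pow(2).mean()                       # additional loss on f_1's output

L1.backward()
print(f0.weight.grad.abs().sum())           # non-zero: phi picked up gradients from L_1
```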

If I understand torch.autograd.grad correctly, it would do the same as cloning and detaching y_0 and then using the cloned y_0 for the loss L_1, is that correct?

Yes, but the difference is that the .grad fields won’t be populated. Instead, torch.autograd.grad() returns the gradients directly.
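
For example (placeholder modules again, standing in for f_0 and f_1):

```python
import torch
import torch.nn as nn

f0, f1 = nn.Linear(8, 8), nn.Linear(8, 8)
y1 = f1(f0(torch.randn(4, 8)))
L1 = y1.pow(2).mean()

theta = list(f1.parameters())
grads = torch.autograd.grad(L1, theta, retain_graph=True)

print(all(p.grad is None for p in theta))   # True: .grad was not populated
# if you want the usual optimizer workflow, accumulate the returned gradients manually
for p, g in zip(theta, grads):
    p.grad = g if p.grad is None else p.grad + g
```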

You can also pass inputs= to .backward() (as ptrblck mentioned); it does the same thing, except that it also populates the .grad fields of the tensors you pass.
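
Putting it together for the layout above (placeholder modules and losses; L_1 should only update theta, while L_2 backprops through everything):

```python
import torch
import torch.nn as nn

f0, f1, f2 = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)   # phi, theta, psi

y0 = f0(torch.randn(4, 8))
y1 = f1(y0)
y2 = f2(y1)
L1 = y1.pow(2).mean()            # additional loss, restricted to f_1
L2 = y2.pow(2).mean()            # main loss, reaches all parameters

# retain_graph=True because L2.backward() reuses part of the same graph
L1.backward(inputs=list(f1.parameters()), retain_graph=True)
L2.backward()

print(f0.weight.grad.abs().sum())   # contains only L_2's contribution
```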