About torch.autograd.grad

Hi, I have a quick question about the torch.autograd.grad function.

First, let me explain my situation.
My model consists of two consecutive networks, A and B, where the output of A (output_A) directly becomes the input of B.
After computing the loss from the outputs of B, my model computes the gradient of the loss with respect to the output of A to obtain a Grad-CAM map, by calling grads = torch.autograd.grad(loss, output_A, retain_graph=True).detach().
It then modifies output_A using the Grad-CAM map (call the result output_A'), feeds output_A' into network B again, and computes the same loss a second time (loss').
Finally, both the loss computed from output_A and the loss' computed from output_A' are backpropagated through the model with (loss+loss').backward() and optimizer.step().
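
For reference, here is a minimal sketch of what I mean. The module definitions, shapes, and the CAM-style reweighting are placeholders rather than my actual code (note that torch.autograd.grad returns a tuple, so the tensor is indexed out before detaching):

```python
import torch
import torch.nn as nn

# Placeholder networks standing in for A and B.
net_A = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
net_B = nn.Sequential(nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    list(net_A.parameters()) + list(net_B.parameters()), lr=0.01
)

x = torch.randn(4, 3, 32, 32)
target = torch.randint(0, 10, (4,))

output_A = net_A(x)                        # output of A, input of B
loss = criterion(net_B(output_A), target)

# Gradient of the loss w.r.t. output_A; autograd.grad returns a tuple.
grads = torch.autograd.grad(loss, output_A, retain_graph=True)[0].detach()

# Placeholder Grad-CAM-style reweighting of output_A.
weights = grads.mean(dim=(2, 3), keepdim=True)
output_A_mod = output_A * weights

loss_mod = criterion(net_B(output_A_mod), target)

optimizer.zero_grad()
(loss + loss_mod).backward()
optimizer.step()
```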

Implemented this way, the model works and I’m able to train it.
My question is: is it okay to call grads = torch.autograd.grad(loss, output_A, retain_graph=True).detach() during the forward pass and then backpropagate with the final (loss+loss').backward()?
In other words, doesn’t grads = torch.autograd.grad(loss, output_A, retain_graph=True).detach() harm the loss part when computing (loss+loss').backward()?
I think it is fine, since torch.autograd.grad just computes the gradients manually without accumulating them, but I want to be sure about it.
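
For example, in a toy case like the following (simple tensors, not my actual model), autograd.grad with retain_graph=True does not populate .grad and the graph is still usable for a later backward():

```python
import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)
loss = ((w * x).sum()) ** 2

# Manually compute the gradient of loss w.r.t. w, keeping the graph alive.
g = torch.autograd.grad(loss, w, retain_graph=True)[0]
print(w.grad)                      # None: nothing was accumulated into .grad
loss.backward()                    # still works because the graph was retained
print(torch.allclose(w.grad, g))   # True: backward() gives the same gradient
```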

Thank you for your consideration in advance :slight_smile:

.detach() will “return” (rather than modifying in place) a new tensor that is detached from the computation graph. This can break the graph and might cause errors or wrong gradient calculations during the backward pass.
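
For example (a quick toy check, not related to your model), detach() returns a new tensor that is cut off from the graph, while the original tensor and the graph behind it are left untouched:

```python
import torch

a = torch.randn(2, requires_grad=True)
b = a * 3
c = b.detach()                 # new tensor, cut off from the graph
print(b.requires_grad)         # True
print(c.requires_grad)         # False: gradients will not flow through c
b.sum().backward()             # the original graph is still intact
print(a.grad)                  # tensor([3., 3.])
```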

Thanks for your reply!

However, aren’t grads = torch.autograd.grad(loss, output_A, retain_graph=True).detach() and (loss+loss').backward() separate procedures?
Does the computed gradient information still remain in the computation graph after calling torch.autograd.grad(loss, output_A, retain_graph=True)?
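
As a quick sanity check (toy tensors, not my model), the gradients returned by torch.autograd.grad do not seem to be part of the graph unless create_graph=True, so detaching them looks like a no-op as far as the graph is concerned:

```python
import torch

w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()

g_plain = torch.autograd.grad(loss, w, retain_graph=True)[0]
g_graph = torch.autograd.grad(loss, w, retain_graph=True, create_graph=True)[0]
print(g_plain.requires_grad)   # False: the returned gradient is already detached
print(g_graph.requires_grad)   # True only when create_graph=True
```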

Hmm, I will try it without detach() and update my results after comparing both cases (with and without detach()).

Right! For your use case, using detach() should not cause any errors.

Sounds good! Feel free to post if you see any errors.

After running the experiments for both cases, the difference in the results is negligible. It seems detach() does not affect the training procedure in my case.