Hi, I have a quick question about the `torch.autograd.grad` function.

First, let me explain my situation.

My model consists of 2 consecutive networks, A & B, where the output of A (`output_A`) directly becomes the input of B.

After calculating the loss using the outputs of B, my model computes the gradient of the loss with respect to the output of A to obtain a Grad-CAM, by performing `grads = torch.autograd.grad(loss, output_A, retain_graph=True)[0].detach()`.

At this stage, my model modifies `output_A` using the Grad-CAM (denote the result `output_A'`), feeds the modified `output_A'` to network B again, and computes the same loss once more.

Finally, both the `loss` computed from `output_A` and the `loss'` computed from `output_A'` are backpropagated through the model using `(loss + loss').backward()` and `optimizer.step()`.

By implementing it this way, the model works and I'm able to train it.
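
For concreteness, here is a minimal sketch of the training step I mean (`net_A`, `net_B`, `criterion`, `optimizer`, and `modify_with_gradcam` are placeholder names for my two networks, loss function, optimizer, and the Grad-CAM-based modification):

```python
import torch

def train_step(net_A, net_B, criterion, optimizer, x, target, modify_with_gradcam):
    optimizer.zero_grad()

    # Forward pass: A -> B -> loss
    output_A = net_A(x)
    output_B = net_B(output_A)
    loss = criterion(output_B, target)

    # Gradient of the loss w.r.t. output_A for Grad-CAM; retain_graph=True keeps
    # the graph alive so that (loss + loss_prime).backward() can still run later.
    grads = torch.autograd.grad(loss, output_A, retain_graph=True)[0].detach()

    # Modify output_A using the Grad-CAM weights and feed it through B again
    output_A_prime = modify_with_gradcam(output_A, grads)
    output_B_prime = net_B(output_A_prime)
    loss_prime = criterion(output_B_prime, target)

    # Backprop both losses together and update the parameters
    (loss + loss_prime).backward()
    optimizer.step()
    return loss.item(), loss_prime.item()
```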

But my question is: is it OK to perform `grads = torch.autograd.grad(loss, output_A, retain_graph=True)[0].detach()` during the forward pass and then backprop the final `(loss + loss').backward()`?

In other words, doesn't the `torch.autograd.grad(...)` call harm the `loss.backward()` part when computing `(loss + loss').backward()`?

I think it is OK, since `torch.autograd.grad` just computes the gradients manually, but I want to be sure about it.

Thank you in advance for your consideration.