Hi everyone -
I am trying to backprop gradients to different parts of a computation graph. I have the following setup:
out0 = model0(input0)
out1 = model1(out0, input1)
out2 = model1(out0, model2(out0))
loss1 = criterion(out1, ground_truth)
loss2 = criterion(out2, ground_truth)
loss3 = criterion(out2, input1)
How can I achieve the following gradient updates for the different networks? Specifically, loss2 depends on all three models, but I only want to update model2's parameters with respect to loss2.
(loss1 + loss3).backward() # w.r.t. all three models' parameters
loss2.backward() # w.r.t. model2.parameters() only
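For reference, here is a minimal runnable sketch of the setup and one possible way to get the second behavior. The concrete modules (nn.Linear / nn.Bilinear), shapes, and the variable out2_m2 are my own stand-ins, not part of the original setup. The idea is: for loss2, recompute out2 with out0.detach() (so no gradient reaches model0) while model1's parameters have requires_grad=False (so model1 gets no gradient either, even though the graph still passes through it to reach model2).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the three models (shapes assumed)
model0 = nn.Linear(4, 4)
model1 = nn.Bilinear(4, 4, 4)   # takes two inputs, like model1 in the post
model2 = nn.Linear(4, 4)
criterion = nn.MSELoss()

input0 = torch.randn(2, 4)
input1 = torch.randn(2, 4)
ground_truth = torch.randn(2, 4)

out0 = model0(input0)
out1 = model1(out0, input1)
out2 = model1(out0, model2(out0))

loss1 = criterion(out1, ground_truth)
loss3 = criterion(out2, input1)
(loss1 + loss3).backward()       # gradients flow into all three models

# For loss2: rebuild out2 on a graph that only reaches model2.
# out0.detach() cuts off model0; freezing model1's parameters
# prevents them from accumulating gradients.
for p in model1.parameters():
    p.requires_grad_(False)
out2_m2 = model1(out0.detach(), model2(out0.detach()))
loss2 = criterion(out2_m2, ground_truth)

grad0_before = model0.weight.grad.clone()
loss2.backward()                 # only model2's gradients change
for p in model1.parameters():
    p.requires_grad_(True)       # restore for later training steps
```

One caveat with this approach: it runs an extra forward pass through model1 and model2, so if the models are large you pay that cost once per step. An alternative with the same effect is to call loss2.backward() on the original graph and then zero out the gradients of model0 and model1 before the optimizer step.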
Thanks!