Backprop gradients w.r.t. partial computation graph

Hi everyone -

I am trying to backprop gradients to different parts of a computation graph. I have the following setup:

out0 = model0(input0)
out1 = model1(out0, input1)
out2 = model1(out0, model2(out0))

loss1 = criterion(out1, ground_truth)
loss2 = criterion(out2, ground_truth)
loss3 = criterion(out2, input1)

How can I achieve the following gradient updates for the different networks? Specifically, loss2 depends on all three models, but I only want to update model2 w.r.t. loss2.

(loss1 + loss3).backward()  # gradients w.r.t. all three models' parameters
loss2.backward()            # gradients w.r.t. model2.parameters() only
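
For reference, here is a minimal self-contained version of the setup. The Linear layers, shapes, and MSE criterion below are just stand-ins, not the real models:

import torch
import torch.nn as nn

# Stand-in modules; the real architectures and shapes differ.
model0 = nn.Linear(8, 8)
model2 = nn.Linear(8, 8)

class TwoInputNet(nn.Module):  # stand-in for model1, which takes two inputs
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 8)

    def forward(self, a, b):
        return self.fc(torch.cat([a, b], dim=-1))

model1 = TwoInputNet()
criterion = nn.MSELoss()

input0 = torch.randn(4, 8)
input1 = torch.randn(4, 8)
ground_truth = torch.randn(4, 8)

out0 = model0(input0)
out1 = model1(out0, input1)
out2 = model1(out0, model2(out0))

loss1 = criterion(out1, ground_truth)
loss2 = criterion(out2, ground_truth)
loss3 = criterion(out2, input1)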

Thanks!

I assume there is a typo in your example and model2 is the one used to generate out2, right?

If so, if you use a single optimizer, the best way I can see is:

model2.zero_grad()
loss2.backward(retain_graph=True)
model0.zero_grad()
model1.zero_grad()
(loss1 + loss3).backward()

opt.step()
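
For completeness, here is a sketch of how one full step could look with that ordering, assuming a single optimizer over all three models (the SGD choice and learning rate are arbitrary, and the stand-in modules from the setup above are reused):

import itertools

opt = torch.optim.SGD(
    itertools.chain(model0.parameters(), model1.parameters(), model2.parameters()),
    lr=1e-2,
)

opt.zero_grad()

# Backprop loss2 first: this writes gradients into all three models,
# but the model0/model1 contributions are cleared right after.
loss2.backward(retain_graph=True)

# Drop the loss2 gradients that reached model0 and model1.
model0.zero_grad()
model1.zero_grad()

# Accumulate the loss1 + loss3 gradients into all three models
# (the graph is still alive thanks to retain_graph=True above).
(loss1 + loss3).backward()

# At this point model2 holds gradients from loss2 and loss3,
# while model0 and model1 only hold gradients from loss1 + loss3.
opt.step()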

Actually, that’s not a typo. model1 takes out0 and the output of model2(out0) as inputs to produce out2, so out2 depends on the parameters of all three models.

Oh sorry, I misread it there.
It doesn’t change my answer though :smiley:
