I’m a little confused about how to detach certain models from the loss computation graph.
Say I have 3 models, A, B and C, that each generate an output given an input, and 3 optimizers, Oa, Ob and Oc, for A, B and C respectively.
Let’s assume I have an initial input x. I also have 3 losses: L1, L2 and L3.
Now assume a scenario where, among other things, the system computes L2 on B(A(x)).
Suppose we want a loss to optimize only certain models, e.g. we want to optimize only B. We can do it in 2 ways:
1)
op = A(x)
op2 = B(op.detach())
cost = L2(op2, ground_truth)
cost.backward()
Ob.step()
# This backprops gradients only through B and optimizes only B.
2)
op = A(x)
op2 = B(op)
cost = L2(op2, ground_truth)
cost.backward()
Ob.step()
# This backprops the loss through both A and B,
# but it doesn't matter because we only call Ob.step().
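To make the two variants concrete, here is a minimal runnable sketch. A and B as small nn.Linear modules, the MSE loss, the shapes and the learning rate are all just hypothetical stand-ins for whatever the real models and losses are:

import torch
import torch.nn as nn

A, B = nn.Linear(4, 4), nn.Linear(4, 2)   # hypothetical stand-in models
Oa = torch.optim.SGD(A.parameters(), lr=0.1)
Ob = torch.optim.SGD(B.parameters(), lr=0.1)
L2 = nn.MSELoss()                          # placeholder loss
x = torch.randn(8, 4)
ground_truth = torch.randn(8, 2)

# Variant (1): detach cuts the graph at A's output.
cost = L2(B(A(x).detach()), ground_truth)
Ob.zero_grad()
cost.backward()
Ob.step()
print(A.weight.grad)                       # None: no gradients ever reach A

# Variant (2): the full graph is kept; both A and B receive gradients,
# but only B's parameters move because only Ob.step() is called.
cost = L2(B(A(x)), ground_truth)
Ob.zero_grad()
cost.backward()
Ob.step()
print(A.weight.grad is not None)           # True: A got gradients anyway

One consequence of (2): A's .grad buffers now hold gradients from L2, so they need to be cleared (e.g. with Oa.zero_grad()) before A's next update, otherwise gradient accumulation will mix them into it.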
Now, if we wanted to optimize A for B(A(x)), we can’t use (1); we have to use (2):
op = A(x)
op2 = B(op)
cost = L2(op2, ground_truth)
cost.backward()
Oa.step()
# Here gradients are calculated for both A and B, but it doesn't matter
# because we are just optimizing A with Oa.
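Continuing the sketch above, a quick sanity check that B's parameters don't move when only Oa.step() is called, even though B's gradients are populated:

b_before = B.weight.detach().clone()
cost = L2(B(A(x)), ground_truth)
Oa.zero_grad(); Ob.zero_grad()
cost.backward()
Oa.step()
print(torch.equal(B.weight, b_before))     # True: B unchanged, only A stepped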
I have understood this, but now let’s say that:
netLoss = L1 + L2 + L3
Here I have the second case above: L2 has both A and B in its computation graph, and we only want to optimize A with respect to L2. The problem is that we want to call netLoss.backward(), and L1 + L2 + L3 together contain A, B and C.
So after netLoss.backward() we have to do:
Oa.step()
Ob.step()
Oc.step()
So Ob will take gradients from L2, and from the other losses that contain B as well. Is there some way to do netLoss.backward() while excluding certain models from the graph for particular losses? E.g. as in (2) above, where we want to update only A in B(A(x)).
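For reference, here is a minimal self-contained sketch of the situation I’m asking about; as before, the nn.Linear modules and the exact wiring of L1 and L3 are just hypothetical placeholders:

import torch
import torch.nn as nn

A, B, C = nn.Linear(4, 4), nn.Linear(4, 2), nn.Linear(2, 2)
Oa = torch.optim.SGD(A.parameters(), lr=0.1)
Ob = torch.optim.SGD(B.parameters(), lr=0.1)
Oc = torch.optim.SGD(C.parameters(), lr=0.1)
mse = nn.MSELoss()                          # placeholder for L1, L2 and L3

x = torch.randn(8, 4)
out_a = A(x)
out_b = B(out_a)                            # L2's graph contains both A and B
out_c = C(out_b)
netLoss = mse(out_a, torch.randn(8, 4)) \
        + mse(out_b, torch.randn(8, 2)) \
        + mse(out_c, torch.randn(8, 2))

Oa.zero_grad(); Ob.zero_grad(); Oc.zero_grad()
netLoss.backward()                          # one backward over the combined graph
Oa.step()
Ob.step()                                   # B's update mixes gradients from L2 and L3
Oc.step()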