How to detach specific components in the loss?

I’m a little confused about how to detach certain models from the loss computation graph.
Suppose I have 3 models, A, B and C, each of which generates an output from a given input,
and 3 optimizers, Oa, Ob and Oc, for A, B and C respectively.
Let’s assume I have an initial input x.
I also have 3 losses: L1, L2 and L3.

Let’s assume a scenario where loss L2 is computed on B(A(x)), among other things.

Suppose we want a loss to optimize only certain models. E.g., if we want to optimize only B, we can do it in 2 ways:

1)
```python
Ob.zero_grad()
op = A(x)
op2 = B(op.detach())  # detach() cuts A out of the graph
cost = L2(op2, ground_truth)
cost.backward()
Ob.step()
# This only backpropagates gradients to B and optimizes B.
```

2)
```python
Ob.zero_grad()
op = A(x)
op2 = B(op)
cost = L2(op2, ground_truth)
cost.backward()
Ob.step()
# This backpropagates the loss to both A and B, but only B is updated
# because we only call Ob.step() (A's .grad is still filled in, though).
```
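Both options can be checked on toy models. This is a minimal sketch that assumes two small `nn.Linear` layers as stand-ins for A and B and uses `nn.MSELoss` as a hypothetical placeholder for L2; it shows that detaching leaves A without any gradient while B's gradient is identical either way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the A and B models and the L2 loss
A = nn.Linear(4, 4)
B = nn.Linear(4, 2)
L2 = nn.MSELoss()
x = torch.randn(8, 4)
ground_truth = torch.randn(8, 2)

# Option (1): detach A's output, so the graph stops at B's input
cost = L2(B(A(x).detach()), ground_truth)
cost.backward()
grad_A_opt1 = A.weight.grad          # None: no gradient reached A
grad_B_opt1 = B.weight.grad.clone()  # B received gradients

# Clear gradients before trying option (2)
A.zero_grad()
B.zero_grad()

# Option (2): no detach, so gradients flow into both B and A
cost = L2(B(A(x)), ground_truth)
cost.backward()
grad_A_opt2 = A.weight.grad.clone()
grad_B_opt2 = B.weight.grad.clone()

print(grad_A_opt1)                               # None
print(torch.allclose(grad_B_opt1, grad_B_opt2))  # True: B's gradient is the same either way
```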

Now, if we wanted to optimize A through B(A(x)), we can’t use (1), because detaching cuts A out of the graph; we have to use (2):

```python
Oa.zero_grad()
op = A(x)
op2 = B(op)
cost = L2(op2, ground_truth)
cost.backward()
Oa.step()
# Here gradients are calculated for both `A` and `B`, but only A is updated
# because we only call Oa.step().
```
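A quick check of that claim, again on hypothetical toy models: `Oa.step()` only updates the parameters `Oa` was constructed with, so B's weights stay unchanged. B's `.grad` is still populated by L2, though, so it should be cleared before B's own update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for A, B, their optimizers and the L2 loss
A = nn.Linear(4, 4)
B = nn.Linear(4, 2)
Oa = torch.optim.SGD(A.parameters(), lr=0.1)
Ob = torch.optim.SGD(B.parameters(), lr=0.1)
L2 = nn.MSELoss()
x, ground_truth = torch.randn(8, 4), torch.randn(8, 2)

A_before = A.weight.detach().clone()
B_before = B.weight.detach().clone()

cost = L2(B(A(x)), ground_truth)
cost.backward()  # fills .grad for both A and B
Oa.step()        # only touches the parameters Oa was built with

print(torch.equal(B.weight, B_before))  # True: B was not updated
print(torch.equal(A.weight, A_before))  # False: A was updated
print(B.weight.grad is not None)        # True: B still holds L2's gradient

# Clear B's stale gradient so it does not leak into B's next update
Ob.zero_grad()
```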

I understand this, but now let’s say that:
netLoss = L1 + L2 + L3
Here L2 is the second case above: both A and B are in its computation graph, and we only want to optimize A with respect to L2. The problem is that we want to call netLoss.backward(), and L1 + L2 + L3 together involve all of A, B and C.

So after netLoss.backward() we have to call:
```python
Oa.step()
Ob.step()
Oc.step()
```

So Ob will receive gradients from L2 as well as from the other losses that involve B. Is there some way to call netLoss.backward() while excluding certain models from parts of the graph? E.g., as in (2) above, where we want to update only A in B(A(x)).
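One alternative that keeps a single `netLoss.backward()` call is to evaluate B inside L2 with *detached copies* of its parameters. The sketch below assumes toy `nn.Linear` models and MSE placeholders for L1, L2 and L3 (all hypothetical). Gradients from L2 still flow through B's computation back into A, but B's parameters only receive gradients from L1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for A, B, C and the data
A, B, C = nn.Linear(4, 4), nn.Linear(4, 2), nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randn(8, 2)

a = A(x)

# L1 and L3: placeholder losses that should update B and C as usual
L1 = F.mse_loss(B(a), y)
L3 = F.mse_loss(C(a), y)

# L2: run B with detached copies of its weight and bias. Gradients still flow
# through B's computation back into A, but B's parameters get nothing from L2.
out2 = F.linear(a, B.weight.detach(), B.bias.detach())
L2 = F.mse_loss(out2, y)

netLoss = L1 + L2 + L3
netLoss.backward()

# Check: B's gradient equals the gradient of L1 alone
grad_B_from_L1 = torch.autograd.grad(F.mse_loss(B(A(x)), y), B.weight)[0]
print(torch.allclose(B.weight.grad, grad_B_from_L1))  # True
```

`F.linear` works here because B is a plain linear layer. For an arbitrary module, `torch.func.functional_call` (PyTorch 2.0 or newer) can run it with a dict of detached parameters, e.g. `{n: p.detach() for n, p in B.named_parameters()}`.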

@apaszke @SimonW @albanD

To suggest the best possible schedule, we would need to know exactly how each of those losses is influenced by the individual modules.

@apaszke Please consider Eqs. 8, 9, 10 and 11 in https://arxiv.org/pdf/1711.06861.pdf
Eq. 8 - Minimize the cross-entropy loss
Eq. 9 - Maximize the entropy of the label (i.e., minimize the negative entropy). This loss is not in the package, so I’ll have to construct it manually from the entropy -∑ p log(p).
Eq. 10 - Minimize the cross-entropy loss
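The Eq. 9 term can be built from `log_softmax`. This is a minimal sketch, assuming the classifier outputs raw logits of shape `(batch, classes)`; `negative_entropy` is a hypothetical helper name. It returns the batch mean of ∑ p log p, which is −H(p), so minimizing it maximizes the entropy.

```python
import math
import torch
import torch.nn.functional as F

def negative_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Batch mean of sum_k p_k * log p_k, i.e. -H(p).

    Minimizing this value maximizes the entropy H(p) = -sum_k p_k log p_k.
    """
    log_p = F.log_softmax(logits, dim=1)  # numerically stable log-probabilities
    p = log_p.exp()
    return (p * log_p).sum(dim=1).mean()

# Uniform predictions over 4 classes have the maximal entropy log(4),
# so the loss equals -log(4) for them
uniform = torch.zeros(3, 4)
print(negative_entropy(uniform).item(), -math.log(4))
```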

The issue is with Eq. 9, where the computation is op = Classifier(Encoder(x)), and this loss should only optimize the Encoder.

As I mentioned in the example above, this would be easy if each loss were backpropagated individually: I could just call Encoder_Optimizer.step().

But the paper’s model is trained end-to-end, which means they just backpropagate Eq. 11, where Eq. 11 = Eq. 8 + Eq. 9 + Eq. 10.

The problem is that another loss, i.e. Eq. 8, is supposed to optimize the Classifier, so the Classifier will end up being updated by gradients from both Eq. 8 and Eq. 9.

I hope I’m understanding this correctly. Is there some way to do this?

Doing multiple backward passes before calling any optimizer will give you end-to-end training too; it’s just a name for joint optimization. I didn’t have time to look into the paper, but you could probably backprop from the Eq. 9 loss first (with retain_graph=True if the later losses reuse part of its graph), zero the grads of the Classifier, and then backprop from Eq. 8 + Eq. 10 (remember that .grad is accumulated instead of overwritten).
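A sketch of that schedule on toy modules. Everything here is a hypothetical stand-in: the encoder/classifier shapes, the random data, and especially the Eq. 10 term, which is a placeholder cross-entropy on made-up target labels rather than the paper's definition. The sanity check confirms that the classifier's final gradient comes only from Eq. 8 + Eq. 10.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the paper's Encoder and Classifier
encoder = nn.Sequential(nn.Linear(10, 16), nn.ReLU())
classifier = nn.Linear(16, 3)
opt_enc = torch.optim.SGD(encoder.parameters(), lr=0.01)
opt_cls = torch.optim.SGD(classifier.parameters(), lr=0.01)

# Random toy data; y_tgt is made up and only used by the Eq. 10 placeholder
x_src, y_src = torch.randn(8, 10), torch.randint(0, 3, (8,))
x_tgt, y_tgt = torch.randn(8, 10), torch.randint(0, 3, (8,))

def negative_entropy(logits):
    """-H(p) averaged over the batch; minimizing it maximizes entropy."""
    log_p = F.log_softmax(logits, dim=1)
    return (log_p.exp() * log_p).sum(dim=1).mean()

opt_enc.zero_grad()
opt_cls.zero_grad()

logits_src = classifier(encoder(x_src))
logits_tgt = classifier(encoder(x_tgt))

# Pass 1: the Eq. 9 term should only update the encoder
loss_eq9 = negative_entropy(logits_tgt)
loss_eq9.backward(retain_graph=True)  # logits_tgt's graph is reused below
opt_cls.zero_grad()                   # discard the classifier's share of Eq. 9

# Pass 2: Eq. 8 + Eq. 10 update both modules; .grad accumulates
loss_eq8 = F.cross_entropy(logits_src, y_src)
loss_eq10 = F.cross_entropy(logits_tgt, y_tgt)  # placeholder, not the paper's Eq. 10
loss_rest = loss_eq8 + loss_eq10

# Sanity check: the classifier's gradient should come from loss_rest alone
expected_cls_grad = torch.autograd.grad(loss_rest, classifier.weight, retain_graph=True)[0]
loss_rest.backward()
print(torch.allclose(classifier.weight.grad, expected_cls_grad))  # True

opt_enc.step()
opt_cls.step()
```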

Ahhh… I did not know that we could backprop losses individually as well. Yes, this makes perfect sense. Thanks a lot, this really helped me out.

@apaszke @ptrblck I just had a small follow-up, a slightly different question.
If model1 produces an output a, which is then given to model2 and later maybe to model3 as well, and we want different models to be updated based on different losses, how does .detach() work?

Let’s say a.detach() is the input to model2. If we calculate the loss and call .backward(), gradients will only be calculated for model2, not model1.

But then, in the next step, when a is given to model3 and .backward() is called on the new loss, are gradients calculated for both model3 and model1, or just for model3?

Basically, does calling .detach() return a copy of the original output a, or does it detach the original a from the computation graph?

.detach() returns a new tensor that is detached from the computation graph; the original a stays attached. You can reuse the “attached” tensor for further computations as long as you don’t reassign it:

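```python
import torch
import torch.nn as nn

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
modelC = nn.Linear(10, 10)

x = torch.randn(1, 10)
a = modelA(x)
b = modelB(a.detach())
b.mean().backward()
print(modelA.weight.grad)  # None: the detached input cut modelA off
print(modelB.weight.grad)  # populated
print(modelC.weight.grad)  # None: modelC has not been used yet

c = modelC(a)              # the original a is still attached to modelA
c.mean().backward()
print(modelA.weight.grad)  # now populated, via modelC
print(modelB.weight.grad)  # unchanged
print(modelC.weight.grad)  # populated
```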
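To answer the "copy or not" part directly: `.detach()` creates a new tensor object without grad history, but it shares the same underlying storage as `a` (so in-place edits to the detached tensor are visible in `a`). A minimal check:

```python
import torch
import torch.nn as nn

modelA = nn.Linear(10, 10)
x = torch.randn(1, 10)
a = modelA(x)

a_det = a.detach()  # new tensor object, same storage, no grad history

print(a.grad_fn is not None)             # True: a itself is still in the graph
print(a_det.requires_grad)               # False
print(a_det.data_ptr() == a.data_ptr())  # True: no copy of the data was made
```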

@ptrblck Thanks a lot. This was really helpful :slight_smile:


@ptrblck What if modelC is modelB, and I want to backpropagate through modelB but optimize modelA?

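```python
import torch
import torch.nn as nn

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
optimForA = torch.optim.SGD(modelA.parameters(), lr=0.1)
optimForB = torch.optim.SGD(modelB.parameters(), lr=0.1)

# Hypothetical placeholder losses; replace with the real loss functions
loss_function_1 = lambda out: out.pow(2)
loss_function_2 = lambda out: out.abs()

# a is needed for both losses
x = torch.randn(1, 10)
a = modelA(x)

b1 = modelB(a.detach())   # loss1 should only reach modelB
loss1 = loss_function_1(b1)
loss1.mean().backward()
optimForB.step()

b2 = modelB(a)            # second forward pass, attached to modelA
loss2 = loss_function_2(b2)
loss2.mean().backward()   # also fills modelB's .grad; zero it before B's next step
optimForA.step()
```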

I will need to compute b1 = modelB(a.detach()) for the first loss and b2 = modelB(a) for the second. This seems kinda wasteful imo. Is there a way to detach after calculation, so I only have to calculate b = modelB(a) once?

I don’t see another approach if the loss functions are different (if both losses were the same, you could just use a single forward pass).
For the second forward pass, you could set the requires_grad attributes of all parameters in modelB to False if you don’t need these gradients.
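A minimal sketch of that suggestion, using `b2.mean()` as a hypothetical placeholder for `loss_function_2`: with modelB's parameters temporarily frozen, the gradient still flows through modelB's operations into `a` and modelA, but modelB's `.grad` is left untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
x = torch.randn(1, 10)
a = modelA(x)

# Temporarily freeze modelB so this pass only produces gradients for modelA
for p in modelB.parameters():
    p.requires_grad_(False)

b2 = modelB(a)        # gradients still flow through modelB's ops into a
b2.mean().backward()  # placeholder for loss_function_2

# Unfreeze modelB for its next update
for p in modelB.parameters():
    p.requires_grad_(True)

print(modelA.weight.grad is not None)  # True
print(modelB.weight.grad)              # None
```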


Okay, thanks. Good to know!
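For reference, newer PyTorch releases (1.8 and later) also accept an `inputs` argument in `Tensor.backward`, which restricts which leaf tensors accumulate gradients. That allows a single forward pass through modelB. This is a sketch with hypothetical placeholder losses, and it differs slightly from the snippet above: loss2's gradient for modelA is computed with modelB's pre-update weights, because both backward calls must run before any optimizer step (an in-place update of modelB's weights would invalidate the tensors saved for loss2's backward).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
optimForA = torch.optim.SGD(modelA.parameters(), lr=0.1)
optimForB = torch.optim.SGD(modelB.parameters(), lr=0.1)

x = torch.randn(1, 10)
a = modelA(x)
b = modelB(a)  # single forward pass through modelB

# Hypothetical placeholder losses standing in for loss_function_1/2
loss1 = b.pow(2).mean()
loss2 = b.abs().mean()

# loss1 only accumulates into modelB's parameters
loss1.backward(inputs=list(modelB.parameters()), retain_graph=True)
# loss2 flows back through modelB but only accumulates into modelA's parameters
loss2.backward(inputs=list(modelA.parameters()))

# Sanity checks (before stepping): each model got the gradient of its own loss only
check_B = torch.autograd.grad(modelB(modelA(x)).pow(2).mean(), modelB.weight)[0]
check_A = torch.autograd.grad(modelB(modelA(x)).abs().mean(), modelA.weight)[0]
print(torch.allclose(modelB.weight.grad, check_B))  # True
print(torch.allclose(modelA.weight.grad, check_A))  # True

# Step only after both backward calls
optimForB.step()
optimForA.step()
```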