Terminate backward pass at different depths for different loss functions

Suppose I have two networks A, B in sequence. Two different loss functions are applied to the resulting features, i.e. in the forward pass I have:

y = A(x)
z = B(y)
loss1 = loss_func1(z)
loss2 = loss_func2(z)

loss1 should only update network A, and loss2 should update both networks A and B.
I have separate optimizers for A and B. Now what I do is:

optim_A.zero_grad()
(loss1 + loss2).backward()
optim_A.step()
optim_B.zero_grad()
loss2.backward()
optim_B.step()

But I think that when updating B, the backward pass goes all the way back to the input x. Am I right?
What if I just want the backward pass to stop at y, to save computation time in case A is very complicated? detach() is not useful here because for loss1 we still need the backward pass to go through A.

Thanks!

I think your approach is alright. A small improvement might be to switch the order in which you update the models.
You could first call loss2.backward(retain_graph=True) and update B, since you'll need this gradient in A a bit later. Then call loss1.backward() and update A.
Here is a small example:


import torch
import torch.nn as nn
import torch.optim as optim

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)

x = torch.randn(1, 10)
target = torch.randn(1, 10)

criterionA = nn.MSELoss()
criterionB = nn.L1Loss()

optimizerA = optim.Adam(modelA.parameters(), lr=1e-3)
optimizerB = optim.Adam(modelB.parameters(), lr=1e-3)

y = modelA(x)
z = modelB(y)

lossA = criterionA(z, target)
lossB = criterionB(z, target)

optimizerA.zero_grad()
optimizerB.zero_grad()

# modelB Update
lossB.backward(retain_graph=True)

# Both models now have gradients from lossB
print(modelA.weight.grad)
print(modelB.weight.grad)

# Update modelB (lossB grad)
optimizerB.step()

# modelA Update
lossA.backward()

# Now modelA has the accumulated gradients of lossA and lossB
print(modelA.weight.grad)

# Update modelA (lossA + lossB)
optimizerA.step()

Thank you for your answer!

Just one quick question: why do you call lossB.backward() with retain_graph=True? Does this option save the gradient so that the other gradient accumulates on top of it?


This option retains the intermediate buffers, which are needed to calculate the gradients.
If you leave it at retain_graph=False, the gradients calculated using lossB will still be there, as you didn't zero them out (e.g. via optimizer.zero_grad()). However, since the intermediate buffers are freed to save memory, the next backward call will throw an error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Setting retain_graph=True allows us to call backward on lossA afterwards, which will then free the intermediate buffers and accumulate the gradients.
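
For illustration, here is a minimal toy sketch of both behaviors (a made-up graph, not the modelA/modelB example above):

import torch

w = torch.randn(3, requires_grad=True)
loss = (w * w).sum()

loss.backward(retain_graph=True)  # intermediate buffers are kept
loss.backward()                   # second call works, gradients accumulate in w.grad

loss2 = (w * w).sum()
loss2.backward()    # buffers are freed after this call
# loss2.backward()  # a second call here would raise the RuntimeError quoted above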


A follow-up question:
In the same setting, what if loss1 only updates network A and loss2 only updates network B, i.e.

optim_B.zero_grad()
loss2.backward(retain_graph=True)
optim_B.step()

optim_A.zero_grad()
loss1.backward()
optim_A.step()

Can we save computation time by preventing back-propagation from loss2 through A?

Thank you for your time!

You could create a detached output of modelA and perform two forward passes through modelB.
I'm not sure if you'll gain any performance, and I'm not sure if there is a better approach.

ya = modelA(x)
yb = ya.detach()
za = modelB(ya)
zb = modelB(yb)

lossA = criterionA(za, target)
lossB = criterionB(zb, target)
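
As a rough sketch of how the updates could then look (assuming the same optimizers as in the earlier example; this part is not from the original post), lossB no longer reaches modelA, so retain_graph is not needed:

optimizerB.zero_grad()
lossB.backward()     # flows through modelB only, since yb is detached
optimizerB.step()    # modelB is updated with lossB

optimizerA.zero_grad()
lossA.backward()     # flows through modelB into modelA
optimizerA.step()    # modelA is updated with lossA

# Note: lossA.backward() also writes gradients into modelB, but they are
# discarded by the next optimizerB.zero_grad() before B's next update.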

@ptrblck can you please give me your input here:
What happens if I do this:

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)

x = torch.randn(1, 10)
target = torch.randn(1, 10)

criterionA = nn.MSELoss()
criterionB = nn.L1Loss()

optimizerA = optim.Adam(modelA.parameters(), lr=1e-3)
optimizerB = optim.Adam(modelB.parameters(), lr=1e-3)

y = modelA(x)
z = modelB(y)

# There is no lossA
lossB = criterionB(z, target)


optimizerA.zero_grad()
optimizerB.zero_grad()


# modelB Update
lossB.backward(retain_graph=True)


# Update modelB (lossB grad)
optimizerB.step()

# Can I update model A in this way?
optimizerA.step()

I don't have lossA any more, but I'm still calling optimizerB.step() and optimizerA.step(). Can I update modelA and modelB in this way?

Yes, both models will be updated, since lossB is calculated using both models.
To make sure it’s right, you could also have a look at the gradients after the backward call:

print(modelA.weight.grad)
print(modelB.weight.grad)

You’ll see that both models got valid gradients.
Also, retain_graph=True is not needed anymore, because you are not calling backward again, so the intermediate tensors can be freed.
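
Put together, a minimal sketch of that single-loss update (same models and optimizers as in the snippet above) could look like this:

optimizerA.zero_grad()
optimizerB.zero_grad()

lossB.backward()           # no retain_graph needed, backward is called only once
print(modelA.weight.grad)  # both gradients are populated
print(modelB.weight.grad)

optimizerB.step()          # update modelB with lossB
optimizerA.step()          # update modelA with lossB as well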
