But I think when updating B, the backward pass goes all the way back to the input x. Am I right?
What if I just want the backward pass to stop at y to save computation in case A is very complicated? detach is not useful here, because for loss1 we still need the backward pass to go through A.
I think your approach is alright. A small improvement might be to switch the order of updating the models.
You could first call loss2.backward(retain_graph=True) and update B as you need this gradient in A a bit later. Then call loss1.backward() and update A.
Here is a small example:
import torch
import torch.nn as nn
import torch.optim as optim

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
x = torch.randn(1, 10)
target = torch.randn(1, 10)
criterionA = nn.MSELoss()
criterionB = nn.L1Loss()
optimizerA = optim.Adam(modelA.parameters(), lr=1e-3)
optimizerB = optim.Adam(modelB.parameters(), lr=1e-3)
y = modelA(x)
z = modelB(y)
lossA = criterionA(z, target)
lossB = criterionB(z, target)
optimizerA.zero_grad()
optimizerB.zero_grad()
# modelB Update
lossB.backward(retain_graph=True)
# Both models now have the gradient using lossB
print(modelA.weight.grad)
print(modelB.weight.grad)
# Update modelB (lossB grad)
optimizerB.step()
# modelA update
lossA.backward()
# Now modelA has the accumulated gradients of lossA and lossB
print(modelA.weight.grad)
# Update modelA (lossA + lossB)
optimizerA.step()
Just one quick question: why do you call lossB.backward() with retain_graph=True? Does this option save the gradient so that the other gradient accumulates on top of it?
This option will retain the intermediate buffers, which are needed to calculate the gradients.
If you leave it at the default retain_graph=False, the gradients calculated from lossB will still be there, since you didn’t zero them out (e.g. via optimizer.zero_grad()). However, since the intermediate buffers are freed to save memory, the next backward call will throw an error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Setting retain_graph=True allows us to call backward on lossA afterwards, which will then free the intermediate buffers and accumulate the gradients.
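To see the difference in isolation, here is a minimal sketch (using a single hypothetical nn.Linear, not the modelA/modelB setup above) that triggers the error on a second backward call and shows that retain_graph=True avoids it:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)
x = torch.randn(1, 10)
out = model(x).sum()

# First backward works, but frees the intermediate buffers by default
out.backward()

# Second backward on the same graph fails with the RuntimeError quoted above
try:
    out.backward()
except RuntimeError as e:
    print("second backward failed:", e)

# With retain_graph=True the graph survives the first call
out2 = model(x).sum()
out2.backward(retain_graph=True)
out2.backward()  # works now; gradients accumulate into model.weight.grad
```

Note that the gradients from all successful backward calls are summed into .grad, which is exactly the accumulation behavior used in the two-loss example above.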
You could create a detached output of modelA and perform two forward passes through modelB.
I’m not sure if you’ll gain any performance, and there might be a better approach.
ya = modelA(x)
yb = ya.detach()  # cuts the graph, so backward through yb stops here
za = modelB(ya)
zb = modelB(yb)
lossA = criterionA(za, target)
lossB = criterionB(zb, target)
Yes, both models will be updated, since lossB is calculated using both models.
To make sure it’s right, you could also have a look at the gradients after the backward call:
You’ll see that both models got valid gradients.
Also, retain_graph=True is not needed anymore, because you are not calling backward on the same graph again: the two losses were computed in separate forward passes through modelB.
The intermediate tensors might thus be freed.
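Putting the detached variant together, here is a runnable sketch (reusing the same model and criterion names as the earlier example) that checks the gradients after each backward call:

```python
import torch
import torch.nn as nn

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
criterionA = nn.MSELoss()
criterionB = nn.L1Loss()
x = torch.randn(1, 10)
target = torch.randn(1, 10)

ya = modelA(x)
yb = ya.detach()          # backward through lossB stops here
za = modelB(ya)
zb = modelB(yb)
lossA = criterionA(za, target)
lossB = criterionB(zb, target)

# lossB only reaches modelB; the detach blocked the path back to modelA
lossB.backward()
gradA_after_lossB = modelA.weight.grad
print(gradA_after_lossB)   # None, since no gradient flowed into modelA

# lossA goes through both models, so now both get valid gradients
lossA.backward()
print(modelA.weight.grad)
print(modelB.weight.grad)
```

No retain_graph=True is needed here, because the two losses live on separate graphs (za and zb come from independent forward passes through modelB).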