Using a combined loss to update two different models

Hi all, I’m trying to accomplish an event detection task with two models, where model_a produces the event tag and model_b produces the event localization (i.e., a label at each time frame). Basically, I’m trying to update both models according to the combined loss of the two.

I tried the following code, which runs successfully, but I am not entirely sure about the loss.backward() part:

model_a = tkmodel_a()
model_b = tkmodel_b()

criterion = nn.BCELoss()  # binary cross-entropy

optimizer_a = optim.Adam(model_a.parameters(), lr=0.001, betas=(0.9, 0.999),
                         eps=1e-08, weight_decay=0., amsgrad=True)

optimizer_b = optim.Adam(model_b.parameters(), lr=0.001, betas=(0.9, 0.999),
                         eps=1e-08, weight_decay=0., amsgrad=True)

if torch.cuda.is_available():
    model_a.cuda()
    model_b.cuda()

# ... forward pass producing predicted_a and predicted_b ...
loss_a = criterion(predicted_a, event_label)
loss_b = criterion(predicted_b, localization_label)

combined_loss_a = loss_a + loss_b
combined_loss_b = loss_a + loss_b

optimizer_a.zero_grad()
combined_loss_a.backward(retain_graph=True)  # keep the graph alive for the second backward
optimizer_a.step()

optimizer_b.zero_grad()
combined_loss_b.backward()
optimizer_b.step()

I have read in several posts that there is no need to do it this way and that I can actually use

combined_loss = loss_a + loss_b
combined_loss.backward()

and, based on the discussion at Optimizing based on another model's output, that if I do the following

optimizer_strong.zero_grad()
optimizer_weak.zero_grad()
combined_strong_loss.backward()
optimizer_strong.step()
optimizer_weak.step()

then I would be updating both model_a and model_b with the combined loss while calling backward only once. Can anyone confirm this?
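
For concreteness, here is how I think that pattern maps onto the variable names in my own code above (just a sketch, assuming the forward pass has already produced predicted_a and predicted_b):

optimizer_a.zero_grad()
optimizer_b.zero_grad()

combined_loss = loss_a + loss_b
combined_loss.backward()  # one backward pass populates .grad on both models

optimizer_a.step()
optimizer_b.step()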

Both approaches will yield the same result, but the method using two losses is wasteful, since you need to call the backward method twice.
Here is a small code snippet to demonstrate the behavior:

import torch
import torch.nn as nn

# Setup
torch.manual_seed(2809)

modelA = nn.Linear(1, 1)
modelB = nn.Linear(1, 1)

criterion = nn.MSELoss()

xA = torch.randn(1, 1)
targetA = torch.randn(1, 1)
xB = torch.randn(1, 1)
targetB = torch.randn(1, 1)

# Standard approach
outA = modelA(xA)
outB = modelB(xB)

lossA = criterion(outA, targetA)
lossB = criterion(outB, targetB)

loss = lossA + lossB
loss.backward()

print(modelA.weight.grad)
print(modelB.weight.grad)

# Other approach: two combined losses, two backward calls
modelA.zero_grad()
modelB.zero_grad()

outA = modelA(xA)
outB = modelB(xB)

lossA = criterion(outA, targetA)
lossB = criterion(outB, targetB)

clossA = lossA + lossB
clossB = lossA + lossB

clossA.backward(retain_graph=True)  # retain the graph so clossB can backpropagate through it
print(modelA.weight.grad)

modelB.zero_grad()  # clear the grads clossA already accumulated into modelB
clossB.backward()
print(modelB.weight.grad)
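
To make the takeaway concrete, here is a minimal sketch of the recommended single-backward training step, reusing the models, data, and criterion from the snippet above (the two optimizers are my addition for illustration):

optimizerA = torch.optim.Adam(modelA.parameters(), lr=1e-3)
optimizerB = torch.optim.Adam(modelB.parameters(), lr=1e-3)

optimizerA.zero_grad()
optimizerB.zero_grad()

loss = criterion(modelA(xA), targetA) + criterion(modelB(xB), targetB)
loss.backward()  # a single backward fills .grad for both models

optimizerA.step()
optimizerB.step()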

Okay, thanks a lot for your reply.