Using a combined loss to update two different models

Hi all, I’m trying to accomplish an event detection task with two models, where model_a produces the event tag and model_b produces the event localization (i.e., a label at each time frame). Basically, I’m trying to update both models according to their combined loss.

I tried the following code, which runs successfully, but I am not entirely sure about the loss.backward() part:

import torch
import torch.nn as nn
import torch.optim as optim

model_a = tkmodel_a()
model_b = tkmodel_b()  # was assigned to model_a by mistake

criterion = nn.BCELoss()  # binary cross entropy

optimizer_a = optim.Adam(model_a.parameters(), lr=0.001, betas=(0.9, 0.999),
                         eps=1e-08, weight_decay=0., amsgrad=True)

optimizer_b = optim.Adam(model_b.parameters(), lr=0.001, betas=(0.9, 0.999),
                         eps=1e-08, weight_decay=0., amsgrad=True)

if torch.cuda.is_available():
    ...  # "Function to do prediction and etc" (elided in the original post)

loss_a = criterion(predicted_a, event_label)
loss_b = criterion(predicted_b, localization_label)

combined_loss_a = loss_a + loss_b
combined_loss_b = loss_a + loss_b



I have read in several posts that there is no need to do it this way and that I can simply use

combined_loss = loss_a + loss_b

and, based on the discussion in "Optimizing based on another model's output", if I do the following


I would be updating both model_a and model_b with combined_loss while calling backward only once. Can anyone confirm this?
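A minimal, self-contained sketch of that single-backward step is below. The tiny nn.Sequential models, the feature size of 8, the 4 time frames, and the random batch are hypothetical stand-ins for tkmodel_a/tkmodel_b and the real data; only the structure (one combined loss, one backward() call, one step() per optimizer) is meant to carry over.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins: model_a outputs one event tag per clip,
# model_b outputs one label per time frame (4 frames here)
model_a = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())
model_b = nn.Sequential(nn.Linear(8, 4), nn.Sigmoid())

criterion = nn.BCELoss()  # binary cross entropy, expects probabilities in [0, 1]
optimizer_a = optim.Adam(model_a.parameters(), lr=0.001, amsgrad=True)
optimizer_b = optim.Adam(model_b.parameters(), lr=0.001, amsgrad=True)

# Hypothetical batch: 2 clips with 8 features each, plus random 0/1 targets
x = torch.randn(2, 8)
event_label = torch.randint(0, 2, (2, 1)).float()
localization_label = torch.randint(0, 2, (2, 4)).float()

# Keep copies of the parameters to check that both models were updated
before_a = [p.detach().clone() for p in model_a.parameters()]
before_b = [p.detach().clone() for p in model_b.parameters()]

model_a.train()
model_b.train()
optimizer_a.zero_grad()
optimizer_b.zero_grad()

predicted_a = model_a(x)
predicted_b = model_b(x)
loss_a = criterion(predicted_a, event_label)
loss_b = criterion(predicted_b, localization_label)

combined_loss = loss_a + loss_b
combined_loss.backward()  # a single call fills .grad for the parameters of both models

optimizer_a.step()
optimizer_b.step()
```

Autograd accumulates gradients into every parameter that contributed to combined_loss, so both models receive their gradients from the one call. A single optimizer over both parameter sets, e.g. optim.Adam(list(model_a.parameters()) + list(model_b.parameters())), would be an equivalent alternative here, since Adam computes its update per parameter.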

Both approaches will yield the same result, but the method using two losses is wasteful, since you need to call the backward method twice.
Here is a small code snippet to demonstrate the behavior:

import torch
import torch.nn as nn

# Setup

modelA = nn.Linear(1, 1)
modelB = nn.Linear(1, 1)

criterion = nn.MSELoss()

xA = torch.randn(1, 1)
targetA = torch.randn(1, 1)
xB = torch.randn(1, 1)
targetB = torch.randn(1, 1)

# Standard approach
outA = modelA(xA)
outB = modelB(xB)

lossA = criterion(outA, targetA)
lossB = criterion(outB, targetB)

loss = lossA + lossB


# Other approach: two combined losses

outA = modelA(xA)
outB = modelB(xB)

lossA = criterion(outA, targetA)
lossB = criterion(outB, targetB)

clossA = lossA + lossB
clossB = lossA + lossB
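The snippet stops before the backward calls. A runnable completion of the comparison might look like the following sketch; the seed, the helper forward_losses(), and the explicit modelB.zero_grad() between the two backward calls are assumptions for illustration (that zero_grad() plays the role of optimizer_b.zero_grad() in a two-optimizer loop).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

modelA = nn.Linear(1, 1)
modelB = nn.Linear(1, 1)
criterion = nn.MSELoss()

xA, targetA = torch.randn(1, 1), torch.randn(1, 1)
xB, targetB = torch.randn(1, 1), torch.randn(1, 1)

def forward_losses():
    # Hypothetical helper: recompute both losses with a fresh graph
    lossA = criterion(modelA(xA), targetA)
    lossB = criterion(modelB(xB), targetB)
    return lossA, lossB

# Standard approach: one combined loss, one backward call
lossA, lossB = forward_losses()
loss = lossA + lossB
loss.backward()
ref_gradA = modelA.weight.grad.clone()
ref_gradB = modelB.weight.grad.clone()

# Other approach: two combined losses, two backward calls
modelA.zero_grad()
modelB.zero_grad()
lossA, lossB = forward_losses()
clossA = lossA + lossB
clossB = lossA + lossB

clossA.backward(retain_graph=True)  # keep the graph alive for the second call
gradA = modelA.weight.grad.clone()  # modelA's gradient is complete at this point
modelB.zero_grad()                  # drop modelB's gradient from clossA
clossB.backward()
gradB = modelB.weight.grad.clone()

print(torch.allclose(gradA, ref_gradA), torch.allclose(gradB, ref_gradB))  # True True
# modelA was not cleared, so it accumulated the gradient from both calls
print(torch.allclose(modelA.weight.grad, 2 * ref_gradA))  # True
```

The last check shows why the placement of zero_grad() matters: clossA and clossB both contain lossA and lossB, so each backward call adds a full gradient to both models. Without clearing in between, the accumulated gradients are twice those of the single-loss approach, and the second backward pass is extra work in any case.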



Okay, thanks a lot for your reply.

At the beginning of the training block, should I write the following?


Yes, that would be necessary if you’ve previously called .eval(), e.g. in a validation loop.
However, I would add these lines anyway, just in case some other part of the code calls model.eval().
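To illustrate why these calls matter, the following sketch toggles two hypothetical models between a training block and a validation helper. The Linear + Dropout architecture is chosen only because Dropout behaves differently in the two modes, and validate() is an illustrative name, not something from the post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins; Dropout is active in train mode and a no-op in eval mode
model_a = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model_b = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(8, 4)

def validate():
    # eval() recursively switches every submodule (Dropout, BatchNorm, ...) to inference mode
    model_a.eval()
    model_b.eval()
    with torch.no_grad():
        return model_a(x), model_b(x)

mode_log = []
for epoch in range(2):
    # validate() left the models in eval mode, so switch back explicitly
    model_a.train()
    model_b.train()
    mode_log.append((model_a.training, model_b[1].training))
    # ... forward pass, combined loss, backward, optimizer steps go here ...
    out_a, out_b = validate()
    mode_log.append((model_a.training, model_b[1].training))

print(mode_log)  # [(True, True), (False, False), (True, True), (False, False)]
```

Calling train() unconditionally at the start of each training block costs nothing and guards against whatever mode an earlier helper left the models in.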


What if I only want to update modelA using modelB’s loss, lossB?
How can I do this?

Assuming modelB’s loss was calculated without using modelA’s parameters, you won’t be able to directly update modelA using this loss. Instead, you could e.g. copy the gradients from modelB to modelA manually (and update modelA afterwards).
However, you would have to make sure this approach actually makes sense.
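For two models with identical architectures, the manual copy could be sketched like this. The nn.Linear(1, 1) models, the SGD optimizer, and the learning rate of 0.1 are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Same architecture, so every parameter of modelA has a same-shaped twin in modelB
modelA = nn.Linear(1, 1)
modelB = nn.Linear(1, 1)
optimizerA = optim.SGD(modelA.parameters(), lr=0.1)

criterion = nn.MSELoss()
xB, targetB = torch.randn(1, 1), torch.randn(1, 1)

# lossB depends only on modelB's parameters
lossB = criterion(modelB(xB), targetB)
lossB.backward()
print(modelA.weight.grad)  # None: modelA was not part of lossB's graph

# Copy modelB's gradients into modelA parameter by parameter, then update modelA
optimizerA.zero_grad()
for pA, pB in zip(modelA.parameters(), modelB.parameters()):
    pA.grad = pB.grad.detach().clone()

weight_before = modelA.weight.detach().clone()
optimizerA.step()  # plain SGD: weight <- weight - lr * grad
```

The pairing via zip() only works because both parameter lists have the same length, order, and shapes; whether modelB’s gradients are a meaningful update direction for modelA is a separate question.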

Thanks @ptrblck. I have model A (say, an MLP) and model B (say, a CNN).
Now, I want to update model A using model B’s loss, but model A was not used to compute anything in model B. How can this be done? Since the two models don’t have the same architecture, I can’t copy the gradients from model B to model A.

I don’t know how this use case could work, since (as you’ve already explained) you won’t be able to copy the gradients from B to A.


Is there no way around this?
Basically, my use case is like this:
I have model 1 which predicts which layers to freeze in model 2.
Then using the loss of model 2, I want to update model 1.

Maybe, but I don’t know, as I don’t see how this use case could work. Consider e.g.:

import torch.nn as nn
from torchvision.models import resnet50

modelA = resnet50()
modelB = nn.Linear(1, 1)

The two models differ in their number of layers and, of course, in their overall structure, so copying the gradients from one to the other won’t work.
Also, I claim it’s even impossible for modelB to achieve the loss created by modelA on any dataset (say, ImageNet).
You can’t directly use the "loss of modelB" to update modelA, since modelA was never used, so you would have to come up with a custom mapping strategy, e.g. for how the gradients should be copied.
Maybe there are other valid approaches which I’m unaware of.
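To make the mismatch concrete, here is a small sketch that uses a hypothetical toy CNN in place of resnet50() (so it runs without torchvision). After lossB.backward(), none of modelA’s parameters has a gradient, and the two parameter lists differ in both count and shape, so there is nothing to pair up for copying.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for resnet50(): a tiny CNN
modelA = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1),
)
modelB = nn.Linear(1, 1)

# lossB is computed from modelB alone
lossB = nn.MSELoss()(modelB(torch.randn(4, 1)), torch.randn(4, 1))
lossB.backward()

# autograd only populates .grad for parameters that are part of lossB's graph
print([p.grad is None for p in modelA.parameters()])  # [True, True, True, True]

# No one-to-one correspondence between the parameter tensors
print([tuple(p.shape) for p in modelA.parameters()])  # [(8, 3, 3, 3), (8,), (1, 8), (1,)]
print([tuple(p.shape) for p in modelB.parameters()])  # [(1, 1), (1,)]
```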


Thanks for the quick reply! I will think about this; if that’s the case, I may have to change my approach.

Let’s also wait for others to chime in, as they might have some ideas or might know similar use cases which you could reuse.


Hi @ptrblck, sorry, but I have to ask a slightly different question.

Consider this use case: I have an encoder-decoder architecture with two loss functions, one for the whole model and the other for just the decoder. I created two optimizers: one responsible for updating all of the model’s parameters, and the other responsible for just the decoder’s.

Now, when I come to this part of the code:


It throws the following error:

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:

1) Use of a module parameter outside the `forward` function.
Please make sure model parameters are not shared across multiple concurrent
forward-backward passes. or try to use _set_static_graph() as
a workaround if this module graph does not change during training loop.

2) Reused parameters in multiple reentrant backward passes.
For example, if you use multiple `checkpoint` functions to wrap the same part of your model,
it would result in the same set of parameters been used by different reentrant backward passes
multiple times, and hence marking a variable ready multiple times. DDP does not support such use
cases in default. You can try to use _set_static_graph() as a workaround
if your module graph does not change over iterations.
Parameter at index 95 has been marked as ready twice. This means that multiple autograd engine 
hooks have fired for this particular parameter during this iteration. You can set the environment variable
TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further

Could you please help me?

Is this error only raised in a distributed setup? Could you post a minimal, executable code snippet to reproduce the error, please?

Yes, the error only occurs when using multiple GPUs. I’m so sorry, but it’s very hard to create a minimal example for this. I use the fairseq framework, and there are a lot of pieces attached to it.

If you don’t mind, let’s keep discussing this in the question’s thread.