How to update a network with a specific loss when there are multiple losses

Let’s say I have a cascaded model:
e.g., x → modelA → modelB → output <-loss-> gt

Now, there are multiple losses (loss1, loss2, loss3) I'd like to compute, but I only want to backpropagate loss1 and loss2 into modelA, not loss3.
modelB, however, needs to be updated using every loss.

I could do it using retain_graph=True, but that requires more GPU memory.
I feel like I could do this by masking the gradients of a particular loss, but I'm not sure how.

Thank you.

Assuming loss1 and loss2 are supposed to create the gradients for modelA and modelB, you could call loss3.backward(retain_graph=True, inputs=list(modelB.parameters())) first, which accumulates gradients only into the tensors passed via the inputs argument, and then call .backward() on the accumulated loss from loss1 and loss2 afterwards.

It would only require more memory if you didn't free the computation graph in the last backward call, which shouldn't be the case here.
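A minimal sketch of this pattern, using two hypothetical linear layers and three made-up losses (the actual models and loss functions are placeholders, not from the original post):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical cascaded setup: x -> modelA -> modelB -> output
modelA = nn.Linear(4, 4)
modelB = nn.Linear(4, 4)

x = torch.randn(2, 4)
gt = torch.randn(2, 4)

out = modelB(modelA(x))

# Three example losses computed from the same output
loss1 = F.mse_loss(out, gt)
loss2 = out.abs().mean()
loss3 = out.pow(2).mean()

# Backprop loss3 only into modelB's parameters: the `inputs` argument
# restricts gradient accumulation to the given tensors, and
# retain_graph=True keeps the graph alive for the second backward call.
loss3.backward(retain_graph=True, inputs=list(modelB.parameters()))

# modelA has received no gradients so far
assert all(p.grad is None for p in modelA.parameters())

# Backprop the accumulated loss1 + loss2 through both models;
# this final call frees the computation graph as usual.
(loss1 + loss2).backward()

assert all(p.grad is not None for p in modelA.parameters())
assert all(p.grad is not None for p in modelB.parameters())
```

After this, modelA's gradients come only from loss1 + loss2, while modelB's gradients contain contributions from all three losses, as requested.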


Thank you for the prompt and accurate solution!!!