Loss.backward() if multiple losses from multiple networks are present

I have a model that can be divided into two submodels.
Model1 takes the input, and then k instances of Model2 each take Model1's output as their input.
So I have k losses.
Currently, I compute loss = loss_1 + … + loss_k and then call loss.backward().
My question is: how does loss.backward() make sure that the gradients of the i-th Model2 are computed from loss_i only, and not from the summed loss?

Hi Daksh, calling loss.backward() populates the .grad attribute of every leaf tensor that requires grad in the computation graph of loss.
Your question would be easier to answer if you could provide some code, please.
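For illustration, here is a minimal sketch of which tensors get .grad populated. The tensors w and x are hypothetical; module parameters such as weight and bias are leaf tensors in the same sense, which is why their .grad is filled in after backward().

```python
import torch

# Leaf tensor: created directly by the user with requires_grad=True
w = torch.tensor([2.0, 3.0], requires_grad=True)
x = torch.tensor([1.0, 4.0])  # plain input, no gradient tracking

y = w * x        # non-leaf: produced by an operation on w
loss = y.sum()   # d(loss)/dw = x
loss.backward()

print(w.is_leaf, y.is_leaf)  # True False
print(w.grad)                # tensor([1., 4.])
```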


If the i-th model contributes only to the i-th loss value, the derivatives of the other loss terms with respect to that model's parameters are zero, so this happens automatically.
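A small sketch of that claim, assuming a hypothetical nn.Linear trunk standing in for Model1 and k nn.Linear heads standing in for the Model2 copies (the squared-output loss is also just a placeholder). It checks that after backward() on the summed loss, each head holds exactly the gradient of its own loss, and that a head does not appear in another head's loss graph at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
k = 3
# Hypothetical stand-ins: a shared trunk ("Model1") and k heads ("Model2" copies)
trunk = nn.Linear(4, 8)
heads = nn.ModuleList([nn.Linear(8, 1) for _ in range(k)])

x = torch.randn(5, 4)
features = trunk(x)
losses = [head(features).pow(2).mean() for head in heads]

# Gradient of loss_i alone w.r.t. head i's weight; keep the graph for later use
per_loss = [
    torch.autograd.grad(losses[i], heads[i].weight, retain_graph=True)[0]
    for i in range(k)
]

# Head 0's weight is not part of loss_1's graph, so its contribution is zero
unused = torch.autograd.grad(
    losses[1], heads[0].weight, retain_graph=True, allow_unused=True
)[0]
print(unused)  # None

total = sum(losses)
total.backward()

for i in range(k):
    # The summed loss gives head i the same gradient as loss_i alone
    print(i, torch.allclose(heads[i].weight.grad, per_loss[i]))
```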

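Another way of viewing this: the gradient of a sum is the sum of the gradients, and backward() accumulates into .grad, so calling loss.backward() on the summed loss is equivalent to calling loss_i.backward() on each loss individually. (When the losses share part of the graph, as they do through Model1, every call except the last needs retain_graph=True.)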

Note that Model1 is affected by all of the losses: it produces the single output that every downstream model consumes, so it contributes to every term.
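A sketch of this for the shared part, again assuming a hypothetical nn.Linear trunk for Model1 and nn.Linear heads for Model2: each loss_i yields its own gradient for the trunk weight, and the trunk's .grad after backward() on the summed loss is the sum of those k contributions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
k = 3
trunk = nn.Linear(4, 8)  # hypothetical "Model1"
heads = nn.ModuleList([nn.Linear(8, 1) for _ in range(k)])  # hypothetical "Model2" copies
x = torch.randn(5, 4)

feats = trunk(x)
losses = [h(feats).pow(2).mean() for h in heads]

# Each loss_i produces its own gradient for the shared trunk weight
contribs = [
    torch.autograd.grad(loss_i, trunk.weight, retain_graph=True)[0]
    for loss_i in losses
]

sum(losses).backward()
# The trunk's gradient from the summed loss equals the sum of all k contributions
print(torch.allclose(trunk.weight.grad, torch.stack(contribs).sum(dim=0), atol=1e-6))
```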

```python
import torch

# Case 1: a single backward() on the summed loss
torch.manual_seed(0)
m1 = torch.nn.Conv2d(1,1,1,1)
m2 = torch.nn.Conv2d(1,2,1,1)
inp = torch.randn(1,1,1,1)
loss1 = m1(inp).prod()
loss2 = m2(inp).prod()
loss = loss1 + loss2
loss.backward()
print(m1.weight.grad, m2.weight.grad)

# Case 2: same seed (identical weights and input), one backward() per loss
torch.manual_seed(0)
m1 = torch.nn.Conv2d(1,1,1,1)
m2 = torch.nn.Conv2d(1,2,1,1)
inp = torch.randn(1,1,1,1)
loss1 = m1(inp).prod()
loss2 = m2(inp).prod()
loss1.backward()
loss2.backward()
print(m1.weight.grad, m2.weight.grad)
```
```
$ python3 multiple.py
tensor([[[[1.2645]]]]) tensor([[[[-0.8376]]],


        [[[-1.8029]]]])
tensor([[[[1.2645]]]]) tensor([[[[-0.8376]]],


        [[[-1.8029]]]])
```
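The demo above uses two independent models, so the two losses share no part of the graph. Below is a sketch of the same comparison for the setup in the question, assuming a hypothetical nn.Linear trunk (Model1) feeding two nn.Linear heads (Model2). Because both losses pass through the trunk, the per-loss version has to keep the graph alive with retain_graph=True on every call except the last; otherwise the second backward() fails because the trunk's saved tensors were already freed.

```python
import torch
import torch.nn as nn

def build():
    # Same seed each time, so both runs start from identical weights and input
    torch.manual_seed(0)
    trunk = nn.Linear(4, 8)                                     # hypothetical "Model1"
    heads = nn.ModuleList([nn.Linear(8, 1) for _ in range(2)])  # hypothetical "Model2" copies
    x = torch.randn(5, 4)
    return trunk, heads, x

def all_grads(trunk, heads):
    params = list(trunk.parameters()) + list(heads.parameters())
    return [p.grad.clone() for p in params]

# (a) One backward() on the summed loss
trunk, heads, x = build()
feats = trunk(x)
losses = [h(feats).pow(2).mean() for h in heads]
sum(losses).backward()
summed = all_grads(trunk, heads)

# (b) One backward() per loss; .grad accumulates across the calls.
# The losses share the trunk's subgraph, so all but the last call retain it.
trunk, heads, x = build()
feats = trunk(x)
losses = [h(feats).pow(2).mean() for h in heads]
for i, loss_i in enumerate(losses):
    loss_i.backward(retain_graph=i < len(losses) - 1)
separate = all_grads(trunk, heads)

print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(summed, separate)))
```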