# Backpropagating the loss in joint training

I have a joint network where the data go through two separate subnetworks and the outputs are concatenated before being sent through a third network. I am wondering how I can backpropagate two losses in such a way that the optimizer updates net1’s parameters using loss1 and net2’s parameters using loss2.

The models could look like this:

``````
import torch.nn as nn


class net1(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, p=0.0):
        super().__init__()
        # num_classes and p are accepted but not used in this snippet
        self.embeddings = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU()
        )

    def forward(self, x):
        return self.embeddings(x)


class net2(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, p=0.0):
        super().__init__()
        # num_classes and p are accepted but not used in this snippet
        self.embeddings = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU()
        )

    def forward(self, x):
        return self.embeddings(x)
``````

Training loop:

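``````
optimizer.zero_grad()  # clear gradients left over from the previous step
loss1 = net1(data)     # stands in for a scalar loss computed from net1's output
loss2 = net2(data)     # stands in for a scalar loss computed from net2's output
loss = loss1 + loss2
loss.backward()
optimizer.step()
``````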

Would that make sure net1 is updated by ‘loss1’ and net2 by ‘loss2’, or would both get a common update from ‘loss’?

Hi,

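Looking at your training loop: yes, net1 will be updated by ‘loss1’ and net2 by ‘loss2’. The derivative of a sum is the sum of the derivatives, and in the loop as written ‘loss1’ depends only on net1’s parameters and ‘loss2’ only on net2’s, so each loss contributes a zero gradient to the other network. You can verify this by comparing the gradients you get from calling backward on the total loss with those from calling backward on the individual losses. The session below does this with dummy parameters.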
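
Written out, the argument is short. This is only a sketch: `\theta_1`, `\theta_2`, `L_1` and `L_2` are my own notation (not names from the post) for net1’s parameters, net2’s parameters, ‘loss1’ and ‘loss2’, and it assumes each loss depends only on its own network’s parameters, as in the training loop above.

```latex
\begin{aligned}
\frac{\partial (L_1 + L_2)}{\partial \theta_1}
  &= \frac{\partial L_1}{\partial \theta_1}
   + \underbrace{\frac{\partial L_2}{\partial \theta_1}}_{=\,0}
   = \frac{\partial L_1}{\partial \theta_1}, \\
\frac{\partial (L_1 + L_2)}{\partial \theta_2}
  &= \underbrace{\frac{\partial L_1}{\partial \theta_2}}_{=\,0}
   + \frac{\partial L_2}{\partial \theta_2}
   = \frac{\partial L_2}{\partial \theta_2}.
\end{aligned}
```

If a loss were instead computed from something that depends on both networks, the corresponding cross term would generally not be zero and this argument would not apply as is.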

``````
>>> import torch
>>>
>>> # Create dummy parameters
>>> param_1 = torch.nn.Parameter(torch.ones(1))
>>> param_2 = torch.nn.Parameter(torch.ones(1))
>>>
>>> # Create dummy losses
>>> loss1 = param_1*5
>>> loss2 = 2**param_2
>>>
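>>> # Backward separately on loss1 and loss2
>>> loss1.backward()
>>> loss2.backward()
>>>
>>> # Check the gradients of param_1 and param_2:
>>> # d(5*p)/dp = 5 and d(2**p)/dp = 2**p * ln(2), about 1.3863 at p = 1
>>> param_1.grad
tensor([5.])
>>> param_2.grad
tensor([1.3863])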
>>>
>>>
>>> # Let's redo the experiment with the same dummy parameters and losses
>>> param_1 = torch.nn.Parameter(torch.ones(1))
>>> param_2 = torch.nn.Parameter(torch.ones(1))
>>> loss1 = param_1*5
>>> loss2 = 2**param_2
>>>
>>> # But this time let's backward on their sum
>>> loss = loss1 + loss2
>>>
>>> # Let's check that the gradients are the same as earlier:
>>> loss.backward()