Splitting a network's output across channels into two losses and combining them into a single loss

The output of my network has 6 channels. In my custom loss function, I split the 6 channels into two groups of 3 channels each. I process the first group with an L1 loss (which gives a scalar), the second group with a BCE loss (also a scalar), and add the two scalars to get the final loss. Then I call loss.backward().

Will autograd handle the gradient flow automatically in this case? If so, how does it do it?


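import torch
import torch.nn.functional as F

class MyLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    # Helper bodies are hypothetical: the post names these helpers but does not show them.
    def l1_loss_calculation(self, output, target):
        return F.l1_loss(output, target)

    def bce_loss_calculation(self, output, target):
        # Assumes output holds logits; use F.binary_cross_entropy for probabilities.
        return F.binary_cross_entropy_with_logits(output, target)

    def forward(self, device, output_batch, target_batch):
        # Tensors are [batch, channel, H, W]; device is unused here.
        output_batch_part_1 = output_batch[:, 0:3, :, :]
        output_batch_part_2 = output_batch[:, 3:6, :, :]
        target_batch_part_1 = target_batch[:, 0:3, :, :]
        target_batch_part_2 = target_batch[:, 3:6, :, :]
        loss_term_1 = self.l1_loss_calculation(output_batch_part_1, target_batch_part_1)
        # ... some processing on output_batch_part_2 ...
        loss_term_2 = self.bce_loss_calculation(output_batch_part_2, target_batch_part_2)
        # ... some normalizations on loss_term_2 ...
        final_loss = loss_term_1 + loss_term_2
        return final_loss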
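To try the setup described in the question end to end, here is a minimal sketch. It assumes a single `nn.Conv2d` as a stand-in for the real network, hypothetical random targets (real-valued for channels 0-2, binary for channels 3-5), and that channels 3-5 are raw logits, so it uses `F.binary_cross_entropy_with_logits` in place of whatever BCE helper the original uses. A single `backward()` on the summed scalar reaches every parameter: the conv filters that produce channels 0-2 receive gradient only from the L1 term, and the filters that produce channels 3-5 only from the BCE term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in network (assumption): any model whose output has 6 channels works.
model = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(2, 3, 8, 8)
# Hypothetical targets: channels 0-2 real-valued, channels 3-5 binary (0/1).
targets = torch.cat(
    [torch.randn(2, 3, 8, 8), torch.randint(0, 2, (2, 3, 8, 8)).float()], dim=1
)

optimizer.zero_grad()
output = model(inputs)  # [batch, 6, H, W]
loss_term_1 = F.l1_loss(output[:, 0:3], targets[:, 0:3])
# Assumes channels 3-5 are raw logits, hence the logits variant of BCE.
loss_term_2 = F.binary_cross_entropy_with_logits(output[:, 3:6], targets[:, 3:6])
final_loss = loss_term_1 + loss_term_2  # a scalar that is still part of the graph
final_loss.backward()  # one backward call populates .grad for every parameter
optimizer.step()

# Filters 0-2 of the conv are driven by the L1 term, filters 3-5 by the BCE term.
grad_l1_filters = model.weight.grad[0:3].abs().sum().item()
grad_bce_filters = model.weight.grad[3:6].abs().sum().item()
print(grad_l1_filters > 0, grad_bce_filters > 0)
```

Any extra processing or normalization applied to `loss_term_2` before the addition is also recorded by autograd, as long as it is done with differentiable tensor operations (not by converting to Python numbers with `.item()`).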

Yes. Autograd records indexing (slicing) operations as well as additions, so your model should receive valid gradients, as this example shows:

import torch

x = torch.randn(1, 6, 4, 4, requires_grad=True)
x0 = x[:, :3]  # channels 0-2
x1 = x[:, 3:]  # channels 3-5
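The example above stops after the slicing step. Below is one possible continuation (an illustration, not the original reply's code) that also addresses the "how" part of the question. Conceptually, the backward of the addition passes the upstream gradient (1.0) unchanged to both loss terms, and the backward of each slice writes its incoming gradient into a zero tensor of `x`'s full shape at the sliced positions. Because the two slices are disjoint, `x.grad[:, :3]` should contain exactly the L1 gradient and `x.grad[:, 3:]` exactly the BCE gradient (overlapping slices would have their gradients summed). The targets `t0` and `t1` are hypothetical, and the sketch assumes `x1` holds logits, so it uses `F.binary_cross_entropy_with_logits`; with `reduction="mean"` the analytic gradients are `sign(x0 - t0) / n` and `(sigmoid(x1) - t1) / n`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Network output with 6 channels, as in the question (shape [N, C, H, W]).
x = torch.randn(1, 6, 4, 4, requires_grad=True)
x0 = x[:, :3]  # channels 0-2 -> L1 loss
x1 = x[:, 3:]  # channels 3-5 -> BCE loss

# Hypothetical targets: real-valued for L1, binary (0/1) for BCE.
t0 = torch.randn(1, 3, 4, 4)
t1 = torch.randint(0, 2, (1, 3, 4, 4)).float()

# Assumes x1 holds raw logits, so the logits variant of BCE is used.
loss = F.l1_loss(x0, t0) + F.binary_cross_entropy_with_logits(x1, t1)
loss.backward()

# Analytic gradient each term sends back to its own slice. Both losses use
# reduction="mean", so each is divided by the element count of its slice.
n = x0.numel()
expected_grad_0 = torch.sign(x0.detach() - t0) / n
expected_grad_1 = (torch.sigmoid(x1.detach()) - t1) / n

print(torch.allclose(x.grad[:, :3], expected_grad_0))
print(torch.allclose(x.grad[:, 3:], expected_grad_1))
```

Both checks should print `True`: each half of `x.grad` is exactly what its own loss would produce on its own, which is what the slicing backward plus the addition backward amount to.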
