How (and where) does PyTorch keep track of everything while computing and propagating losses through a network?

I am struggling to understand how PyTorch keeps track of which loss belongs to which output. Let me explain (all the code below comes from this implementation of a text segmentation algorithm, which uses a fully convolutional network: [https://github.com/princewang1994/TextSnake.pytorch]).
The last layers of the CNN look like this:

        self.predict = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(16, self.output_channel, kernel_size=1, stride=1, padding=0)
        ) # with self.output_channel = 7

Basically, the CNN has 7 outputs (one per output channel). The algorithm then computes the losses:

class TextLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def ohem(self, predict, target, train_mask, negative_ratio=3.):
        pos = (target * train_mask).byte()
        neg = ((1 - target) * train_mask).byte()

        n_pos = pos.float().sum()

        if n_pos.item() > 0:
            loss_pos = F.cross_entropy(predict[pos], target[pos], reduction='sum')
            loss_neg = F.cross_entropy(predict[neg], target[neg], reduction='none')
            n_neg = min(int(neg.float().sum().item()), int(negative_ratio * n_pos.float()))
        else:
            loss_pos = 0.
            loss_neg = F.cross_entropy(predict[neg], target[neg], reduction='none')
            n_neg = 100
        loss_neg, _ = torch.topk(loss_neg, n_neg)

        return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()

    def forward(self, input, tr_mask, tcl_mask, sin_map, cos_map, radii_map, train_mask):
        """
        calculate textsnake loss
        :param input: (Variable), network predict, (BS, 7, H, W)
        :param tr_mask: (Variable), TR target, (BS, H, W)
        :param tcl_mask: (Variable), TCL target, (BS, H, W)
        :param sin_map: (Variable), sin target, (BS, H, W)
        :param cos_map: (Variable), cos target, (BS, H, W)
        :param radii_map: (Variable), radius target, (BS, H, W)
        :param train_mask: (Variable), training mask, (BS, H, W)
        :return: loss_tr, loss_tcl, loss_radii, loss_sin, loss_cos
        """

        tr_pred = input[:, :2].permute(0, 2, 3, 1).contiguous().view(-1, 2)  # (BSxHxW, 2)
        tcl_pred = input[:, 2:4].permute(0, 2, 3, 1).contiguous().view(-1, 2)  # (BSxHxW, 2)
        sin_pred = input[:, 4].contiguous().view(-1)  # (BSxHxW,)
        cos_pred = input[:, 5].contiguous().view(-1)  # (BSxHxW,)

        # normalize sin and cos so that sin^2 + cos^2 = 1
        scale = torch.sqrt(1.0 / (sin_pred ** 2 + cos_pred ** 2))
        sin_pred = sin_pred * scale
        cos_pred = cos_pred * scale

        radii_pred = input[:, 6].contiguous().view(-1)  # (BSxHxW,)
        train_mask = train_mask.view(-1)  # (BSxHxW,)

        tr_mask = tr_mask.contiguous().view(-1)
        tcl_mask = tcl_mask.contiguous().view(-1)
        radii_map = radii_map.contiguous().view(-1)
        sin_map = sin_map.contiguous().view(-1)
        cos_map = cos_map.contiguous().view(-1)

        # loss_tr = F.cross_entropy(tr_pred[train_mask], tr_mask[train_mask].long())
        loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())

        loss_tcl = 0.
        tr_train_mask = train_mask * tr_mask
        if tr_train_mask.sum().item() > 0:
            loss_tcl = F.cross_entropy(tcl_pred[tr_train_mask], tcl_mask[tr_train_mask].long())

        # geometry losses
        loss_radii, loss_sin, loss_cos = 0., 0., 0.
        tcl_train_mask = train_mask * tcl_mask
        if tcl_train_mask.sum().item() > 0:
            ones = radii_map.new(radii_pred[tcl_mask].size()).fill_(1.).float()
            loss_radii = F.smooth_l1_loss(radii_pred[tcl_mask] / radii_map[tcl_mask], ones)
            loss_sin = F.smooth_l1_loss(sin_pred[tcl_mask], sin_map[tcl_mask])
            loss_cos = F.smooth_l1_loss(cos_pred[tcl_mask], cos_map[tcl_mask])

        return loss_tr, loss_tcl, loss_radii, loss_sin, loss_cos

Finally, it adds up all the losses before backpropagating the total loss.

I have 2 questions:
1/ Regarding the ohem loss (the first method of the class TextLoss): 2 of the 7 outputs are passed to this function (as ‘predict’), and nowhere is it specified what each output should represent, yet the algorithm still manages to train them as expected. Also, when inspecting the returned tensor (i.e. the error), it holds only one value, so how does PyTorch know how to train the two masks differently?

2/ Similarly, the sum of all the errors is a tensor with a single value, so how does PyTorch know which error corresponds to which output?

Maybe looking at the second question already provides some insight into the first:

At the end of the loss calculation, there is typically a sum or mean that reduces everything to a single value. If you work out the chain rule for this summation, backpropagating the total is mathematically equivalent to backpropagating each of the individual losses and adding up the resulting gradients.
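
To make this concrete, here is a minimal sketch on a toy stand-in for the last layer. The names `head`, `loss_cls` and `loss_reg`, the 3-channel layout and the random targets are all made up for illustration and are not taken from the TextSnake code. It checks that the gradient of the summed loss equals the sum of the per-loss gradients, that each 1x1 kernel only receives gradient from the loss terms that actually read its channel, and that cross-entropy pushes the two class channels in opposite directions even though the loss is a single number.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy head: a 1x1 conv with 3 output channels
# (channels 0-1: a two-class map, channel 2: a regression map).
head = nn.Conv2d(4, 3, kernel_size=1)
x = torch.randn(2, 4, 5, 5)
out = head(x)  # (2, 3, 5, 5)

# Loss 1 only reads channels 0 and 1.
cls_logits = out[:, :2].permute(0, 2, 3, 1).reshape(-1, 2)
cls_target = torch.randint(0, 2, (cls_logits.shape[0],))
loss_cls = F.cross_entropy(cls_logits, cls_target)

# Loss 2 only reads channel 2.
reg_pred = out[:, 2].reshape(-1)
reg_target = torch.randn(reg_pred.shape)
loss_reg = F.smooth_l1_loss(reg_pred, reg_target)

total = loss_cls + loss_reg  # a single scalar

# Gradients of the head's kernels w.r.t. the total and each part.
(grad_total,) = torch.autograd.grad(total, head.weight, retain_graph=True)
(grad_cls,) = torch.autograd.grad(loss_cls, head.weight, retain_graph=True)
(grad_reg,) = torch.autograd.grad(loss_reg, head.weight, retain_graph=True)

# Backpropagating the sum is the same as summing the separate gradients.
sum_matches = torch.allclose(grad_total, grad_cls + grad_reg, atol=1e-6)

# Each kernel only gets gradient from the loss that uses its channel.
zeros = torch.zeros_like(grad_total)
reg_skips_cls_kernels = torch.allclose(grad_reg[:2], zeros[:2])
cls_skips_reg_kernel = torch.allclose(grad_cls[2], zeros[2])

# Gradient w.r.t. the two class logits: per pixel, the two entries
# cancel out, i.e. the channels are pushed in opposite directions.
(grad_logits,) = torch.autograd.grad(loss_cls, cls_logits, retain_graph=True)
opposite_push = torch.allclose(
    grad_logits.sum(dim=1), torch.zeros(grad_logits.shape[0]), atol=1e-6
)

print(sum_matches, reg_skips_cls_kernels, cls_skips_reg_kernel, opposite_push)
```

The bookkeeping lives in the autograd graph recorded during the forward pass: every slice, view and loss call remembers which tensor it came from, so the single scalar is enough to route the right gradient back to each channel.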

Turning back to the first question: in addition to differentiating all the losses, when you mask things with y = mask * x, the chain rule applies the mask to the gradient, too: dloss/dx = mask * dloss/dy. So that is where the mask comes in.
This is also why, when the loss or loss gradient becomes undefined (NaN), you need to use torch.where(mask, x, torch.zeros((), device=x.device, dtype=x.dtype)) for masking; otherwise you get mask * NaN = NaN.
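
A small sketch of both points, using made-up tensors (`x`, `mask`, `v` and `keep` are illustrative names, not from the TextSnake code). The first part checks that a multiplicative mask also masks the gradient; the second part compares `mask * x` with `torch.where` when a masked-out entry is NaN.

```python
import torch

# Part 1: the mask is applied to the gradient as well.
x = torch.tensor([1.0, 4.0, 2.0, 9.0], requires_grad=True)
mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
loss = (mask * x ** 2).sum()
loss.backward()
grad_x = x.grad.clone()  # expected to equal mask * 2 * x

# Part 2: a NaN in a masked-out position.
v = torch.tensor([1.0, float("nan"), 3.0], requires_grad=True)
keep = torch.tensor([True, False, True])

# Multiplying by 0 does not remove the NaN: 0 * NaN = NaN.
loss_mul = (keep.to(v.dtype) * v).sum()

# torch.where selects a clean zero instead of touching the NaN.
zero = torch.zeros((), device=v.device, dtype=v.dtype)
loss_where = torch.where(keep, v, zero).sum()
loss_where.backward()
grad_v = v.grad.clone()  # masked-out entry gets gradient 0

print(grad_x, loss_mul, loss_where, grad_v)
```

One caveat: `torch.where` only cleans up the selection step itself. If the NaN is produced by the gradient of an operation applied before the `torch.where` (for example a square root of a negative number), that operation's backward can still yield NaN, so the invalid inputs may need to be replaced before that operation rather than after it.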

Best regards

Thomas

Thank you for your prompt reply, Thomas. Regarding your answer to my second question, I completely understand that the error is ‘shared’ across most of the network. However, in the last layer there are seven 1 * 1 * 16 kernels, each of which affects only its corresponding output, so how are these kernels trained? If these kernels all use the total error to train, won't all the outputs be trained to do the same thing?

Also, I think I was a bit unclear in my first question (I misused the word mask). Basically, of the two outputs passed to the ohem function, one is trained to output high values for a given pixel if it belongs to a text region and low values otherwise, and the other is trained to do the exact opposite. How does the algorithm train them to do such different tasks? Since only one error value is returned, it intuitively seems to me that they are trained to do the same thing (although that is of course not what happens).