@ptrblck, I have a follow-up question to that.
Using torch.mean() means the zeroed-out elements still count toward the denominator, so they pull the average down and scale the gradients during backpropagation. I'm wondering whether it makes more sense to divide by the number of non-masked elements instead.
import torch

pad = 2
tags = torch.tensor([0, 1, 1, 0, 1, 2])

# for this example, let's pretend this is the per-element loss tensor we got
# from BCEWithLogitsLoss(reduction='none')
loss = torch.tensor([0.001, -0.3, 0.9, 0.7, 0.6, 0.8])

loss_mask = tags != pad
# loss_mask: tensor([ True,  True,  True,  True,  True, False])

loss_masked = loss.where(loss_mask, torch.tensor(0.0))
# loss_masked: tensor([ 0.0010, -0.3000,  0.9000,  0.7000,  0.6000,  0.0000])

loss_masked.mean()                   # tensor(0.3168) -- averages over all 6 elements
loss_masked.sum() / loss_mask.sum()  # tensor(0.3802) -- averages over the 5 unmasked elements
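
For comparison, here is a minimal sketch of two equivalent ways to get the count-based average with the same tensors as above; the clamp on the denominator is only a guard for the edge case where an entire batch is padding:

loss[loss_mask].mean()                            # tensor(0.3802) -- boolean indexing drops the padded elements entirely
loss_masked.sum() / loss_mask.sum().clamp(min=1)  # tensor(0.3802) -- avoids dividing by zero if everything is masked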