How to correctly weight MSE loss for padded sequences

I have embedded sequences of shape (batch, sequence, dim) and a padding mask of shape (batch, sequence), where False means "padding".
Imagine an NLP problem where sentences have varying numbers of words, each capped at a maximum length of sequence.
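For concreteness, such a mask can be built from the true sequence lengths (a minimal sketch, where lengths is a hypothetical tensor of per-sentence lengths; it produces exactly the keep_mask used below):

import torch

lengths = torch.tensor([2, 3])   # hypothetical true lengths per sentence
max_len = 3
# True where the position index is a real token, False where it is padding
keep_mask = torch.arange(max_len) < lengths.unsqueeze(-1)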

Here is an example:

import torch
import torch.nn as nn

y_pred = torch.randn(2, 3, 4)   # (batch, sequence, dim)
y = torch.randn(2, 3, 4)
keep_mask = torch.tensor([[True, True, False],   # sequence 0 has 2 real tokens
                          [True, True, True]])   # sequence 1 has 3 real tokens

I think a fair way would be to weight the loss so that short sequences contribute as much as long ones. Mathematically, I don't want to bias my model to do better on longer sequences. I have not found a better way than:

loss_fn = nn.MSELoss(reduction="none")
# per-sequence weights: each real token gets 1 / (number of real tokens in its sequence)
weight_sequence = keep_mask.float() / keep_mask.sum(-1, keepdim=True)
loss = loss_fn(y_pred, y) * weight_sequence.unsqueeze(-1)
loss = loss.mean()
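An equivalent way to express this intent, which may read more clearly, is a per-sequence masked mean followed by a batch mean (a sketch using the same tensors; it differs from the weighted .mean() above only by the constant factor sequence, because .mean() also divides by the padded length):

per_elem = loss_fn(y_pred, y).mean(-1)                        # (batch, sequence)
per_seq = (per_elem * keep_mask).sum(-1) / keep_mask.sum(-1)  # masked mean per sequence
loss = per_seq.mean()                                         # each sequence weighted equally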

Maybe you could suggest a better way to perform this, or a solution to my problem.

Hi,
can you specify the problem with an example?
Do you want to create a network for token generation/token classification, where the output can be shorter than the target?
MSE is a token-based metric, so in my opinion the number of tokens shouldn't matter and the loss should be aggregated across the batch. But if you tell me more about the problem, I might be able to provide a better answer.
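If a plain per-token aggregation is what you want, boolean indexing keeps it to one line (a minimal sketch reusing y_pred, y, and keep_mask from the first post):

loss_fn = nn.MSELoss(reduction="none")
# global per-token mean: every real token counts equally,
# so longer sequences contribute more tokens to the loss
loss = loss_fn(y_pred, y)[keep_mask].mean()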

Your current approach to calculating the loss is reasonable, but there is a simpler way to compute a masked loss. You can try the following approach:

  1. Zero out the padded positions in the predicted and target sequences using keep_mask.
  2. Calculate the element-wise loss; the padded positions contribute exactly zero.
  3. Divide the sum of the losses by the total number of non-padded elements.

Try this:

import torch
import torch.nn as nn

y_pred = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
keep_mask = torch.tensor([[True, True, False],
                          [True, True, True]])

# zero out padded positions in both tensors so their loss is exactly zero
y_pred_masked = y_pred * keep_mask.unsqueeze(-1)
y_masked = y * keep_mask.unsqueeze(-1)
loss_fn = nn.MSELoss(reduction="none")
loss = loss_fn(y_pred_masked, y_masked)
# total number of non-padded elements: real tokens times the feature dim
num_non_padded = keep_mask.sum() * y.size(-1)
mean_loss = loss.sum() / num_non_padded
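Note that this is a global per-token mean rather than the per-sequence weighting from the first post: sequences with more real tokens contribute proportionally more. As a quick sanity check (a sketch reusing the tensors above), it matches the boolean-indexing form from the previous reply:

alt = loss_fn(y_pred, y)[keep_mask].mean()
assert torch.allclose(mean_loss, alt)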