I have embedded sequences of size (batch, sequence, dim) and a padding mask of size (batch, sequence), where False means “padding”.

Imagine an NLP problem where I have sentences with varying numbers of words, each kept below a maximum length of “sequence”.

Here is an example:

```
import torch
from torch import nn

y_pred = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
keep_mask = torch.tensor([[True, True, False],
                          [True, True, True]])
```

I think a fair approach would be to weight the loss so that short sequences contribute as much as long ones; mathematically, I don’t want to bias my model toward doing better on longer sequences. I have not found anything better than:

```
loss_fn = nn.MSELoss(reduction="none")
# Each row of weights sums to 1, so every sequence contributes equally.
weight_sequence = keep_mask / keep_mask.sum(-1, keepdim=True)
loss = loss_fn(y_pred, y) * weight_sequence.unsqueeze(-1)
loss = loss.mean()
```
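As a sanity check on this weighting, here is a small sketch (using the same toy shapes) that compares the weighted loss against an explicit per-sequence average. The two agree up to a constant factor of the padded sequence length, so gradients differ only by a fixed scale:

```python
import torch
from torch import nn

torch.manual_seed(0)
y_pred = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
keep_mask = torch.tensor([[True, True, False],
                          [True, True, True]])

loss_fn = nn.MSELoss(reduction="none")
weight_sequence = keep_mask / keep_mask.sum(-1, keepdim=True)
loss = (loss_fn(y_pred, y) * weight_sequence.unsqueeze(-1)).mean()

# Reference: mean over valid tokens and features within each sequence,
# then averaged over the batch.
elem = loss_fn(y_pred, y) * keep_mask.unsqueeze(-1)
per_seq = elem.sum(dim=(1, 2)) / (keep_mask.sum(-1) * y.size(-1))
reference = per_seq.mean()

# The weighted loss equals the per-sequence average divided by the
# (padded) sequence length, here 3.
assert torch.allclose(loss * y.size(1), reference)
```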

Maybe you could suggest a better way to do this, or another solution to my problem.

Hi,

can you illustrate the problem with a concrete example?

Do you want to build a network for token generation/token classification, where the output can be shorter than the target?

MSE is a token-based metric, so in my opinion the number of tokens shouldn’t matter and the loss can simply be aggregated across the batch. But if you tell me more about the problem, I might be able to give a better answer.
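For illustration, a plain token-level aggregation in this spirit (a sketch reusing the toy tensors from the question, where every valid token counts once regardless of which sequence it came from) could look like:

```python
import torch
from torch import nn

y_pred = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
keep_mask = torch.tensor([[True, True, False],
                          [True, True, True]])

loss_fn = nn.MSELoss(reduction="none")
elem = loss_fn(y_pred, y)                        # (batch, seq, dim)
valid = keep_mask.unsqueeze(-1).expand_as(elem)  # broadcast mask over features
# Mean over every valid element; padded tokens never enter the average.
token_loss = elem[valid].mean()
```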

Your current approach to calculating the loss is reasonable, but here is a simpler alternative. Note that it averages over all non-padded tokens in the batch, so each token (rather than each sequence) contributes equally:

- Mask the predicted and target sequences with the keep_mask.
- Calculate the element-wise loss; masked positions contribute zero.
- Divide the total loss by the number of non-padded elements.

Try this:

```
import torch
from torch import nn

y_pred = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
keep_mask = torch.tensor([[True, True, False],
                          [True, True, True]])

# Zero out padded positions in both tensors so they contribute no loss.
y_pred_masked = y_pred * keep_mask.unsqueeze(-1)
y_masked = y * keep_mask.unsqueeze(-1)
loss_fn = nn.MSELoss(reduction="none")
loss = loss_fn(y_pred_masked, y_masked)
# Normalize by the number of valid elements (valid tokens x feature dim),
# not by the full batch x seq x dim count.
num_valid = keep_mask.sum() * y.size(-1)
mean_loss = loss.sum() / num_valid
```
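A more compact variant (a sketch on the same toy tensors) selects only the valid positions with boolean indexing, so the normalization over valid elements happens automatically inside the mean:

```python
import torch
from torch import nn

y_pred = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
keep_mask = torch.tensor([[True, True, False],
                          [True, True, True]])

# Boolean indexing with a (batch, seq) mask on a (batch, seq, dim) tensor
# flattens to (num_valid_tokens, dim), dropping padded positions entirely.
y_pred_valid = y_pred[keep_mask]
y_valid = y[keep_mask]
mean_loss = nn.MSELoss()(y_pred_valid, y_valid)
```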