Passing padded tensors as input to non-recurrent models


I am using the output of a transformer encoder as input to an MLP network. The encoder input is zero-padded so that all sequences in a batch have the same length, and I pass src_key_padding_mask=mask to the encoder. The MLP operates on the encoder output at every time-step, i.e. it takes sequence_length * batch_size vectors of shape encoder_dim as input:

Padded Batch: [ seq_length, batch_size, dim ] →
Encoder output: [ seq_length, batch_size, encoder_dim ] →
MLP output: [ seq_length, batch_size, output_dim ]
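For reference, a minimal sketch of this pipeline (the layer sizes, head count, and padding pattern are arbitrary placeholders; note that the encoder output dimension equals d_model, so encoder_dim == dim here):

```python
import torch
import torch.nn as nn

# Hypothetical dimensions, for illustration only.
seq_length, batch_size, dim = 10, 4, 32
output_dim = 8

encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
mlp = nn.Sequential(nn.Linear(dim, 64), nn.ReLU(), nn.Linear(64, output_dim))

src = torch.randn(seq_length, batch_size, dim)
# src_key_padding_mask is [batch_size, seq_length], True at padded positions.
mask = torch.zeros(batch_size, seq_length, dtype=torch.bool)
mask[:, 7:] = True  # e.g. the last 3 steps of every sequence are padding

enc_out = encoder(src, src_key_padding_mask=mask)  # [seq_length, batch_size, dim]
out = mlp(enc_out)                                 # [seq_length, batch_size, output_dim]
```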

Since the encoder output includes the padded positions, the MLP produces predictions at those positions as well, and they would affect the loss. What is the best way to avoid this? I have been zeroing out the predicted values with a binary mask and calculating the loss as follows:

def sum_mask_loss(mask, output, labels, loss_func):
    # mask: [seq_length, batch_size], True at padded positions.
    # Expand the inverted mask to [seq_length, batch_size, output_dim].
    newmask = torch.unsqueeze((~mask).int(), 2).repeat(1, 1, output.shape[2])

    # Zero out predictions and labels at padded positions, then divide the
    # summed loss by the number of unpadded elements (this assumes loss_func
    # is constructed with reduction='sum').
    output = newmask * output
    labels = newmask * labels
    loss = loss_func(output, labels) / newmask.sum()

    return loss
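For context, here is a minimal self-contained example of how I call it (the shapes, the padding pattern, and the sum-reduced MSE loss are just illustrative):

```python
import torch

def sum_mask_loss(mask, output, labels, loss_func):
    # mask: [seq_length, batch_size], True at padded positions.
    newmask = torch.unsqueeze((~mask).int(), 2).repeat(1, 1, output.shape[2])
    output = newmask * output
    labels = newmask * labels  # zero labels too, so padding contributes nothing
    return loss_func(output, labels) / newmask.sum()

seq_length, batch_size, output_dim = 6, 3, 4
output = torch.randn(seq_length, batch_size, output_dim)
labels = torch.randn(seq_length, batch_size, output_dim)
mask = torch.zeros(seq_length, batch_size, dtype=torch.bool)
mask[4:, :] = True  # last two time-steps of every sequence are padding

# reduction='sum', so dividing by newmask.sum() gives a mean over real elements.
loss_func = torch.nn.MSELoss(reduction="sum")
loss = sum_mask_loss(mask, output, labels, loss_func)
```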

But I am not sure this is the best way. Any advice would be appreciated!