nn.TransformerEncoder - all nan values when src_key_padding_mask is provided

I have a use case where I am dealing with sequences of variable lengths. I would like to use nn.TransformerEncoder, but whenever a sequence is shorter than the maximum sequence length in the batch, all of its output values are nan after the forward pass.

Explanation of the code

I am including the fully reproducible code below.

In the example below, batch_size is 2, the maximum sequence length is 3, and I provide two sequences: one with 2 tokens and the other one with 3 tokens. Naturally, the sequence with 2 tokens needs to be padded in order to be fed to nn.TransformerEncoder.
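As a quick sanity check on the padded layout, here is a minimal sketch. It assumes a small embedding size (d_model = 8, chosen only for brevity; the post uses 2048) and a padding value of 0.0 rather than -inf. With the default batch_first=False, pad_sequence returns a tensor of shape (S, N, E), which is the layout nn.TransformerEncoder expects by default.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

d_model = 8  # small embedding size for illustration (the post uses 2048)

# Two sequences of different lengths: 2 tokens and 3 tokens
sequences = [torch.rand(2, d_model), torch.rand(3, d_model)]

# With batch_first=False (the default) the result has shape
# (S, N, E) = (max_len, batch, d_model)
padded = pad_sequence(sequences, padding_value=0.0)

print(padded.shape)  # torch.Size([3, 2, 8])
# Time step 2 of the first (shorter) sequence is the padding slot
print(padded[2, 0])  # all zeros
```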

To do this, I need to provide src_key_padding_mask of shape (N, S), where N is the batch size and S is the sequence length, which gives a separate padding mask for each sequence in the batch. A value of True marks a padded position, so for the sequence of length 2 the mask is [False, False, True].
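The nested loop that fills the mask in the code below can also be written with broadcasting. A minimal sketch, assuming the lengths are held in a tensor: position j of sequence i is padding exactly when j >= length_i.

```python
import torch

seq_lengths = torch.tensor([2, 3])  # lengths from the example
max_seq_length = 3

# Broadcasting (1, S) against (N, 1) yields an (N, S) boolean mask,
# the shape expected for src_key_padding_mask (True = padding)
positions = torch.arange(max_seq_length).unsqueeze(0)     # shape (1, S)
src_padding_mask = positions >= seq_lengths.unsqueeze(1)  # shape (N, S)

print(src_padding_mask)
# tensor([[False, False,  True],
#         [False, False, False]])
```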

To my surprise, providing src_key_padding_mask results in nan values at every position of any sequence whose length is less than max_sequence_length.

Note: I have looked through the other reports of all-nan outputs, but those turned out to be “user error”: some sequence had every one of its tokens padded. That is not the case in my scenario.
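One mechanism worth ruling out: src_key_padding_mask only changes the attention scores (a masked key gets a score of -inf, hence a softmax weight of exactly 0). It does not touch the padded values themselves. With padding_value=float('-inf'), each linear projection of a padded token is a sum of +inf and -inf terms (the weights have mixed signs), so it comes out as nan or ±inf. Since IEEE arithmetic gives 0 * inf = nan, that value vector still contaminates the weighted sum for every query in the same sequence, while sequences with no padding are unaffected because attention never crosses batch elements. The toy tensors below are hypothetical and only illustrate this single step, not the full model.

```python
import torch

# Raw attention scores of one query against three keys; key 2 is padding
scores = torch.tensor([0.5, 1.0, 0.0])
key_padding_mask = torch.tensor([False, False, True])

# Masking works as intended on the scores: the padded key gets weight exactly 0
weights = torch.softmax(scores.masked_fill(key_padding_mask, float('-inf')), dim=-1)
print(weights[2])  # tensor(0.)

# Hypothetical value vectors; the padded token's row is non-finite because
# its input was -inf before the linear projection
values = torch.tensor([[1.0, 2.0],
                       [3.0, 4.0],
                       [float('inf'), float('-inf')]])

# Attention-weighted sum over keys, written elementwise: 0 * inf = nan
out = (weights.unsqueeze(-1) * values).sum(dim=0)
print(out)  # tensor([nan, nan])
```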

```python
from torch.nn.utils.rnn import pad_sequence
import torch
import torch.nn as nn

ego_encoder_layer = nn.TransformerEncoderLayer(d_model=2048, nhead=16, activation='gelu')
ego_transformer_encoder = nn.TransformerEncoder(ego_encoder_layer, num_layers=6)

batch_size = 2
max_seq_length = 3
seq_lengths = [2, 3]

# Two sequences of 2048-dim features: one with 2 tokens, one with 3 tokens
sequences = [torch.rand(2, 2048), torch.rand(3, 2048)]

# Pad to shape (S, N, E) = (3, 2, 2048); the shorter sequence is padded with -inf
ego_seq2 = pad_sequence(sequences, padding_value=float('-inf'))

# Build the padding mask BEFORE the forward pass: True means "padding token, skip"
src_padding_mask = torch.zeros((batch_size, max_seq_length), dtype=torch.bool)
for ind, sl in enumerate(seq_lengths):
    for j in range(sl, max_seq_length):
        src_padding_mask[ind, j] = True

ego_transformer_features = ego_transformer_encoder(ego_seq2, src_key_padding_mask=src_padding_mask)

print('- seq_lengths, src_padding_mask -')
print(seq_lengths, src_padding_mask)

print('transformer features')
print(ego_transformer_features[:, :, :5])
```


```
torch.Size([3, 2, 2048])
torch.Size([3, 2, 2048])
- seq_lengths, src_padding_mask -
[2, 3] tensor([[False, False,  True],
        [False, False, False]])
transformer features
torch.Size([3, 2, 2048])
tensor([[[    nan,     nan,     nan,     nan,     nan],
         [ 1.0203, -0.1215,  1.7757,  2.8037,  0.2745]],

        [[    nan,     nan,     nan,     nan,     nan],
         [ 1.4554,  0.1737,  1.9316,  2.8318,  1.0204]],

        [[    nan,     nan,     nan,     nan,     nan],
         [ 0.4414,  1.2467,  0.6188,  2.2753,  0.2954]]],
```
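For comparison, here is a sketch of the same pipeline with a finite padding value (0.0) and the mask built before the forward call. It assumes d_model=16 and nhead=2 (smaller than the post's 2048 and 16) purely to keep it fast, and calls eval() only to disable dropout. If the -inf padding is the culprit, the output should contain no nan; the mask, not the padding value, is what tells attention to ignore padded keys.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
d_model, nhead = 16, 2  # assumption: small sizes for a fast check

encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, activation='gelu')
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder.eval()  # disable dropout so the check is deterministic

sequences = [torch.rand(2, d_model), torch.rand(3, d_model)]
seq_lengths = torch.tensor([len(s) for s in sequences])
max_seq_length = int(seq_lengths.max())

# Pad with a finite value: (S, N, E) = (3, 2, 16)
src = pad_sequence(sequences, padding_value=0.0)

# Mask built before the forward pass: True marks padding positions
src_padding_mask = torch.arange(max_seq_length).unsqueeze(0) >= seq_lengths.unsqueeze(1)

with torch.no_grad():
    out = encoder(src, src_key_padding_mask=src_padding_mask)

print(out.shape)               # torch.Size([3, 2, 16])
print(torch.isnan(out).any())  # tensor(False)
```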

@albanD Any thoughts on this?