I have a use case where I am dealing with sequences of variable lengths. I want to make use of `nn.TransformerEncoder`, but for some reason, if a given sequence is shorter than the maximum sequence length, all of its values come out as `nan` in the forward pass.
I am including the fully reproducible code below.
In the example below, `batch_size` is 2, the maximum sequence length is 3, and I provide two sequences: one with 2 tokens and one with 3 tokens. Naturally, the sequence with 2 tokens needs to be padded before it can be fed to the encoder. To do this, I need to provide a `src_key_padding_mask` of shape `(N, S)`, where N is the `batch_size` and S is the sequence length, so that each batch element gets its own padding mask. For the sequence of length 2, a value of `True` must be provided at the position that is being padded, which means the mask for that sequence is `[False, False, True]`.
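As a sanity check on the mask construction described above, here is a small vectorized sketch (my own addition; the `lengths` list mirrors the sequence lengths used in the post):

```python
import torch

# Build an (N, S) key padding mask from per-sequence token counts.
# True marks a padded position that attention should ignore.
lengths = [2, 3]        # token counts for the two sequences
max_len = max(lengths)  # S = 3

# Position index j is padding for sequence i iff j >= lengths[i].
mask = torch.arange(max_len).unsqueeze(0) >= torch.tensor(lengths).unsqueeze(1)
print(mask)
# tensor([[False, False,  True],
#         [False, False, False]])
```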
To my surprise, providing the `src_key_padding_mask` results in all-`nan` values for every sequence whose length is less than the maximum sequence length.
Note: I have looked through the other issues where users were getting all-`nan` values, but those were due to “user error”: they provided sequences in which every token was padded. That is not the case in my scenario.
```python
from torch.nn.utils.rnn import pad_sequence
import torch
import torch.nn as nn

ego_encoder_layer = nn.TransformerEncoderLayer(d_model=2048, nhead=16, activation='gelu')
ego_transformer_encoder = nn.TransformerEncoder(ego_encoder_layer, num_layers=6)

sequences = []
batch_size = 2
src_padding_mask = torch.zeros((batch_size, 3)).type(torch.bool)
sequences.append(torch.rand(2, 2048))
sequences.append(torch.rand(3, 2048))
seq_lengths = [2, 3]

ego_seq2 = pad_sequence(sequences, padding_value=float('-inf'))
print(ego_seq2.shape)

ego_transformer_features = ego_transformer_encoder(ego_seq2, src_key_padding_mask=src_padding_mask)
print(ego_transformer_features.shape)

max_seq_length = 3
for ind, sl in enumerate(seq_lengths):
    for j in range(sl, max_seq_length):
        src_padding_mask[ind, j] = True  # True means this is a padding token - skip

print('- seq_lengths, src_padding_mask -')
print(seq_lengths, src_padding_mask)
print('transformer features')
print(ego_transformer_features.shape)
ego_transformer_features[:, :, :5]
```
```
torch.Size([3, 2, 2048])
torch.Size([3, 2, 2048])
- seq_lengths, src_padding_mask -
[2, 3] tensor([[False, False,  True],
        [False, False, False]])
transformer features
torch.Size([3, 2, 2048])
tensor([[[    nan,     nan,     nan,     nan,     nan],
         [ 1.0203, -0.1215,  1.7757,  2.8037,  0.2745]],

        [[    nan,     nan,     nan,     nan,     nan],
         [ 1.4554,  0.1737,  1.9316,  2.8318,  1.0204]],

        [[    nan,     nan,     nan,     nan,     nan],
         [ 0.4414,  1.2467,  0.6188,  2.2753,  0.2954]]],
       grad_fn=<SliceBackward>)
```
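One thing worth isolating (my own check, not part of the original post): whether the `nan` comes from the mask or from the `-inf` values that `pad_sequence(..., padding_value=float('-inf'))` writes into the input itself. A `-inf` activation vector already turns into `nan` inside `LayerNorm`, before any attention mask is consulted, which is why a common workaround is to pad with `0.0` and rely on `src_key_padding_mask` to ignore the padded positions:

```python
import torch
import torch.nn as nn

# A single "token" vector filled with -inf, as appears at padded
# positions when padding_value=float('-inf') is used.
x = torch.full((1, 4), float('-inf'))

# LayerNorm subtracts the mean: (-inf) - (-inf) = nan, so the entire
# output is nan regardless of any src_key_padding_mask.
print(nn.LayerNorm(4)(x))
```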