TransformerEncoder truncates output when some token positions are masked by `src_key_padding_mask` across batch

When one or more columns of `src_key_padding_mask` (shape [N, S]) are entirely True, i.e. some token position is padded in every sequence of the batch, TransformerEncoder drops those columns from the output, which causes shape inconsistencies downstream. Minimal example:

Class init

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=512, activation='gelu', norm_first=False, batch_first=True)  # batch_first=True to match the (N, S, E) inputs below
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True)

Forward

x = torch.randn(2000, 6, 256)  # (N, S, E) with batch_first=True
mask = torch.ones(2000, 6)     # (N, S); True = padded
mask[:, 0] = 0  # here 5 columns (positions 1-5) are fully masked instead of just one
mask = mask.bool()
out = self.transformer_encoder(src=x, src_key_padding_mask=mask)
assert out.shape[1] == 6  # fails: the sequence dimension of the output is only 1

I don’t think this behavior is desirable: we often want to aggregate at the sequence level, and silently removing positions makes scatter-style aggregation functions such as scatter_mean hard to use, because the output no longer lines up with the original sequence positions (see the sketch below).
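For context, the sketch below shows the kind of sequence-level aggregation that breaks. It uses a plain masked mean over the sequence dimension (standing in for scatter_mean) and assumes the encoder output keeps all S positions so that it lines up with the padding mask:

def masked_mean(out, padding_mask):
    # out: (N, S, E) encoder output; padding_mask: (N, S) with True = padded
    keep = (~padding_mask).unsqueeze(-1).float()  # (N, S, 1)
    summed = (out * keep).sum(dim=1)              # (N, E), sum over unpadded positions
    counts = keep.sum(dim=1).clamp(min=1.0)       # (N, 1), avoid division by zero
    return summed / counts                        # per-sequence mean over unpadded positions

If the encoder silently drops fully padded positions, out.shape[1] no longer matches padding_mask.shape[1] and this kind of pooling fails with a shape mismatch.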

Update: this is related to the enable_nested_tensor option. Setting it to False makes the output shape consistent again (see the snippet below). I have run into a lot of problems with nested tensors and the fast path (including in a previous post), and I think it would be very helpful if PyTorch set this option to False by default; it is just not as stable as it should be.
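For reference, a minimal sketch of the workaround, assuming a PyTorch version where TransformerEncoder accepts the enable_nested_tensor argument (whether the fast path is actually taken also depends on the version, eval mode, and gradient settings):

encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=512, activation='gelu', norm_first=False, batch_first=True)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=False)  # disable the nested-tensor fast path

With this setting the output keeps the full (N, S, E) shape even when some positions are padded across the entire batch.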

Could you create an issue on GitHub describing how the usage of NestedTensor causes issues in your use case, please?

Sure, thanks. I have created one already.