If src_key_padding_mask (shape [N, K]) has a column k whose entries are all True (i.e., position k is padded in every sequence of the batch), TransformerEncoder with enable_nested_tensor=True will drop that column from the output, so the output sequence length no longer matches the input. Minimal example:
```python
import torch
import torch.nn as nn

# batch_first=True so that x is [N, K, E] and the mask is [N, K];
# the nested-tensor conversion also requires batch_first=True.
encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=512,
                                           activation='gelu', norm_first=False, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True)
transformer_encoder.eval()  # the nested-tensor conversion only happens in eval mode

x = torch.randn(2000, 6, 256)
mask = torch.ones(2000, 6)
mask[:, 0] = 0  # here I masked 5 columns instead of just one
mask = mask.bool()

with torch.no_grad():  # the conversion is also skipped while grad is enabled
    out = transformer_encoder(src=x, src_key_padding_mask=mask)
assert out.shape[1] == 6  # AssertionError: the actual size is only 1
```
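For comparison, a minimal sketch (same setup as above, with the `.eval()`/`no_grad` details assumed by me rather than taken from the report) showing that simply turning the conversion off keeps the full sequence length, at the cost of losing the nested-tensor speedup:

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=512,
                                           activation='gelu', norm_first=False, batch_first=True)
# With the nested-tensor conversion disabled, all K positions survive in the output.
encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=False)
encoder.eval()

x = torch.randn(2000, 6, 256)
mask = torch.ones(2000, 6, dtype=torch.bool)
mask[:, 0] = False  # only position 0 is a real token

with torch.no_grad():
    out = encoder(src=x, src_key_padding_mask=mask)
assert out.shape == (2000, 6, 256)  # sequence length preserved
```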
I don’t think this is good behavior: oftentimes we want to aggregate at the sequence level, and this silent removal breaks scatter-style aggregation functions (e.g. a scatter-mean over per-position indices), because the index tensor no longer lines up with the truncated output.
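To illustrate the misalignment and one possible workaround when enable_nested_tensor must stay on, here is a sketch that re-pads the truncated output back to the original length before aggregating. The helper `pad_back`, the segment ids, and the manual scatter-mean via `index_add_` are all my own illustrative stand-ins, not anything from the original post:

```python
import torch

def pad_back(out: torch.Tensor, orig_len: int) -> torch.Tensor:
    """Re-pad a truncated [N, K', E] encoder output to [N, orig_len, E] with zeros."""
    n, k, e = out.shape
    if k == orig_len:
        return out
    return torch.cat([out, out.new_zeros(n, orig_len - k, e)], dim=1)

# Illustrative scatter-mean over segments of the sequence axis.
out = torch.randn(2000, 1, 256)              # truncated output (K' = 1 instead of 6)
out = pad_back(out, orig_len=6)              # restore K = 6 so indices line up again
segment = torch.tensor([0, 0, 0, 1, 1, 1])   # segment id for each of the K positions
num_segments = 2
sums = out.new_zeros(out.size(0), num_segments, out.size(2))
sums.index_add_(1, segment, out)             # sum features within each segment
counts = torch.bincount(segment, minlength=num_segments).to(out.dtype)
means = sums / counts.view(1, -1, 1)         # [N, num_segments, E]
```

Without the re-padding step, `segment` (length 6) would not match the truncated output (length 1) and the aggregation would either crash or silently mix up positions.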