If src_key_padding_mask (shape [N, K]) has a column k whose entries are all True (i.e., position k is padded in every sequence of the batch), TransformerEncoder with enable_nested_tensor=True will drop that column from the output, so the output sequence length no longer matches the input. Minimal example:
```python
import torch
import torch.nn as nn

# batch_first=True so that x is [N, K, E] and the mask is [N, K];
# the nested-tensor conversion also requires batch_first=True.
encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=512,
                                           activation='gelu', norm_first=False, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True)
transformer_encoder.eval()  # the nested-tensor conversion only happens in eval mode

x = torch.randn(2000, 6, 256)
mask = torch.ones(2000, 6)
mask[:, 0] = 0  # here I masked 5 columns instead of just one
mask = mask.bool()

with torch.no_grad():  # the conversion is also skipped while grad is enabled
    out = transformer_encoder(src=x, src_key_padding_mask=mask)
assert out.shape[1] == 6  # AssertionError: the actual size is only 1
```
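For comparison, a minimal sketch (same setup as above, with the `.eval()`/`no_grad` details assumed by me rather than taken from the report) showing that simply turning the conversion off keeps the full sequence length, at the cost of losing the nested-tensor speedup:

```python
import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=512,
                                           activation='gelu', norm_first=False, batch_first=True)
# With the nested-tensor conversion disabled, all K positions survive in the output.
encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=False)
encoder.eval()

x = torch.randn(2000, 6, 256)
mask = torch.ones(2000, 6, dtype=torch.bool)
mask[:, 0] = False  # only position 0 is a real token

with torch.no_grad():
    out = encoder(src=x, src_key_padding_mask=mask)
assert out.shape == (2000, 6, 256)  # sequence length preserved
```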
I don’t think this is good behavior: oftentimes we want to aggregate at the sequence level, and this silent removal breaks scatter-style aggregation functions (e.g. a scatter-mean over per-position indices), because the index tensor no longer lines up with the truncated output.
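To illustrate the misalignment and one possible workaround when enable_nested_tensor must stay on, here is a sketch that re-pads the truncated output back to the original length before aggregating. The helper `pad_back`, the segment ids, and the manual scatter-mean via `index_add_` are all my own illustrative stand-ins, not anything from the original post:

```python
import torch

def pad_back(out: torch.Tensor, orig_len: int) -> torch.Tensor:
    """Re-pad a truncated [N, K', E] encoder output to [N, orig_len, E] with zeros."""
    n, k, e = out.shape
    if k == orig_len:
        return out
    return torch.cat([out, out.new_zeros(n, orig_len - k, e)], dim=1)

# Illustrative scatter-mean over segments of the sequence axis.
out = torch.randn(2000, 1, 256)              # truncated output (K' = 1 instead of 6)
out = pad_back(out, orig_len=6)              # restore K = 6 so indices line up again
segment = torch.tensor([0, 0, 0, 1, 1, 1])   # segment id for each of the K positions
num_segments = 2
sums = out.new_zeros(out.size(0), num_segments, out.size(2))
sums.index_add_(1, segment, out)             # sum features within each segment
counts = torch.bincount(segment, minlength=num_segments).to(out.dtype)
means = sums / counts.view(1, -1, 1)         # [N, num_segments, E]
```

Without the re-padding step, `segment` (length 6) would not match the truncated output (length 1) and the aggregation would either crash or silently mix up positions.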