TransformerEncoder with src_key_padding_mask can produce different-sized outputs in eval mode

TransformerEncoder produces different-sized outputs in eval and train modes, while a single TransformerEncoderLayer or MultiheadAttention produces the same-sized tensor in both modes. Specifically, when every sample in the batch is shorter than the padded sequence length (i.e., the trailing positions are marked as padding in src_key_padding_mask for all samples), the output's sequence dimension is reduced to the longest actual sequence length in the batch instead of the padded (maximum possible) length. Below is a sample script:

import torch

mh = torch.nn.MultiheadAttention(128, 8, batch_first=True)
layer = torch.nn.TransformerEncoderLayer(128, 8, batch_first=True)
encoder = torch.nn.TransformerEncoder(layer, 2)

# (batch, seq_len, embed_dim)
x = torch.randn(7, 5, 128)
attn_mask = torch.tensor([[False, False, False, True, True],
                          [False, False, False, True, True],
                          [False, False, True, True, True],
                          [False, False, True, True, True],
                          [False, False, False, True, True],
                          [False, False, True, True, True],
                          [False, False, False, True, True]], dtype=torch.bool)
# x could have been (7, 3, 128), since the longest unmasked sequence is only 3 positions,
# but here it is padded to (7, 5, 128)

mh_train_w_grad, _ = mh(x, x, x, key_padding_mask=attn_mask)
enc_train_w_grad = encoder(x, src_key_padding_mask=attn_mask)

with torch.no_grad():
    mh_train_wo_grad, _ = mh(x, x, x, key_padding_mask=attn_mask)
    enc_train_wo_grad = encoder(x, src_key_padding_mask=attn_mask)

mh.eval()
encoder.eval()

mh_eval_w_grad, _ = mh(x, x, x, key_padding_mask=attn_mask)
enc_eval_w_grad = encoder(x, src_key_padding_mask=attn_mask)

with torch.no_grad():
    mh_eval_wo_grad, _ = mh(x, x, x, key_padding_mask=attn_mask)
    enc_eval_wo_grad = encoder(x, src_key_padding_mask=attn_mask)
print(f"Train mode multi-head attention: w/grad {mh_train_w_grad.shape} w/o {mh_train_wo_grad.shape}")
print(f"Train mode transformer encoder: w/grad {enc_train_w_grad.shape} w/o {enc_train_wo_grad.shape}")
print(f"Eval mode multi-head attention: w/grad {mh_eval_w_grad.shape} w/o {mh_eval_wo_grad.shape}")
print(f"Eval mode transformer encoder: w/grad {enc_eval_w_grad.shape} w/o {enc_eval_wo_grad.shape}")

Is this expected behavior? Of course, I could have ensured the sizes before supplying the input (see the sketch below), but this is not the behavior I was expecting, especially since train and eval modes behave differently. I will open an issue if this is indeed unexpected behavior. Is there a catch I am missing?
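
(For context, by ensuring the sizes I mean something like the untested sketch below, which trims the input and the mask to the longest real sequence in the batch so that the padded length and the maximum actual length coincide; the names refer to the script above.)

real_lens = (~attn_mask).sum(dim=1)   # actual length of each sample
max_len = int(real_lens.max())        # 3 for the mask above
x_trim = x[:, :max_len]               # (7, 3, 128)
mask_trim = attn_mask[:, :max_len]    # (7, 3)
out = encoder(x_trim, src_key_padding_mask=mask_trim)
# out should come back as (7, 3, 128) in both train and eval modes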

PS: This is the case for torch==2.0.0 and torch==1.13.0 on macOS. When I try the same script with torch==1.12.0, it produces the output I expect.


I guess the discrepancy might be caused by the use of the fast_path implementation.
CC @mikekgfb as one of the code owners.
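
If that is the cause, a quick way to check might be to construct the encoder with the nested-tensor conversion disabled and compare shapes. This is an untested sketch; it assumes the enable_nested_tensor constructor flag available in recent releases is what gates the conversion.

import torch

layer = torch.nn.TransformerEncoderLayer(128, 8, batch_first=True)
# With enable_nested_tensor=False the encoder should not convert the padded
# input to a nested tensor, so the sequence dimension should stay at 5
# (assumption to verify against your installed version).
encoder_no_nt = torch.nn.TransformerEncoder(layer, 2, enable_nested_tensor=False)
encoder_no_nt.eval()

x = torch.randn(7, 5, 128)
mask = torch.zeros(7, 5, dtype=torch.bool)
mask[:, 3:] = True  # the last two positions of every sample are padding

with torch.no_grad():
    out = encoder_no_nt(x, src_key_padding_mask=mask)
print(out.shape)  # torch.Size([7, 5, 128]) expected if the fast path is the culprit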