TransformerEncoder produces different-sized outputs in eval and train modes, while a single TransformerEncoderLayer or MultiheadAttention produces the same-sized tensor in both modes. Namely, when the trailing positions of the sequence dimension are marked as padding for every sample in src_key_padding_mask (i.e., every sequence in the batch is shorter than the padded length, so those columns carry no information), the output is truncated to the longest unpadded sequence length in the batch instead of the padded (maximum possible) length. Below is a sample script:
import torch
mh = torch.nn.MultiheadAttention(128, 8, batch_first=True)
layer = torch.nn.TransformerEncoderLayer(128, 8, batch_first=True)
encoder = torch.nn.TransformerEncoder(layer, 2)
# (batch, seq_len, embed_dim)
x = torch.randn(7, 5, 128)
# True marks a padded position; no sample uses more than 3 positions
attn_mask = torch.tensor([[False, False, False, True, True],
                          [False, False, False, True, True],
                          [False, False, True, True, True],
                          [False, False, True, True, True],
                          [False, False, False, True, True],
                          [False, False, True, True, True],
                          [False, False, False, True, True]], dtype=torch.bool)
# the input could have been (7, 3, 128), since no sample uses more than 3 positions,
# but it is deliberately padded to (7, 5, 128)
mh_train_w_grad, _ = mh(x, x, x, key_padding_mask=attn_mask)
enc_train_w_grad = encoder(x, src_key_padding_mask=attn_mask)
with torch.no_grad():
    mh_train_wo_grad, _ = mh(x, x, x, key_padding_mask=attn_mask)
    enc_train_wo_grad = encoder(x, src_key_padding_mask=attn_mask)
mh.eval()
encoder.eval()
mh_eval_w_grad, _ = mh(x, x, x, key_padding_mask=attn_mask)
enc_eval_w_grad = encoder(x, src_key_padding_mask=attn_mask)
with torch.no_grad():
    mh_eval_wo_grad, _ = mh(x, x, x, key_padding_mask=attn_mask)
    enc_eval_wo_grad = encoder(x, src_key_padding_mask=attn_mask)
print(f"Train mode multi-head attention: w/grad {mh_train_w_grad.shape} w/o {mh_train_wo_grad.shape}")
print(f"Train mode transformer encoder: w/grad {enc_train_w_grad.shape} w/o {enc_train_wo_grad.shape}")
print(f"Eval mode multi-head attention: w/grad {mh_eval_w_grad.shape} w/o {mh_eval_wo_grad.shape}")
print(f"Eval mode transformer encoder: w/grad {enc_eval_w_grad.shape} w/o {enc_eval_wo_grad.shape}")
Is this expected behavior? Of course, I could have ensured the sizes matched before supplying the input (a sketch of what I mean follows below), but this is not the behavior I was expecting, especially since the modules behave differently in train and eval modes. I will open an issue if this is indeed unexpected behavior. Is there a catch I am missing?
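For reference, this is the kind of size handling I meant. It is only a minimal sketch; trim_to_longest_valid and pad_back_to_full_length are helper names I made up for illustration, not PyTorch APIs:

import torch
import torch.nn.functional as F

def trim_to_longest_valid(x, key_padding_mask):
    # keep only as many positions as the longest unpadded sequence in the batch
    longest = int((~key_padding_mask).sum(dim=1).max())
    return x[:, :longest, :], key_padding_mask[:, :longest]

def pad_back_to_full_length(out, full_len):
    # right-pad the sequence dimension with zeros back to the original length
    return F.pad(out, (0, 0, 0, full_len - out.size(1)))

x_trim, mask_trim = trim_to_longest_valid(x, attn_mask)
out = encoder(x_trim, src_key_padding_mask=mask_trim)
out = pad_back_to_full_length(out, x.size(1))  # back to (7, 5, 128)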
PS: This is the case for torch==2.0.0 and torch==1.13.0 on macOS. When I run the same script with torch==1.12.0, it produces the output I expect.
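My (unverified) guess is that the version difference comes from the nested-tensor fast path that TransformerEncoder can take in newer releases when a padding mask is supplied in inference mode. If that is the cause, constructing the encoder with enable_nested_tensor=False should avoid it; this is only a sketch, and I have not confirmed it restores the 1.12.0 behavior:

# assumption: the nested-tensor fast path is responsible; disable it at construction time
encoder_no_nt = torch.nn.TransformerEncoder(layer, 2, enable_nested_tensor=False)
encoder_no_nt.eval()
with torch.no_grad():
    out = encoder_no_nt(x, src_key_padding_mask=attn_mask)
print(out.shape)  # expecting torch.Size([7, 5, 128]) if the fast path was the cause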