TransformerEncoder mask not working with eval() and no_grad()

Hi. I’m getting an odd error with a TransformerEncoder that occurs only when the model is in eval mode and I am also using torch.no_grad(). Below is some simple code to reproduce the error on PyTorch version 1.12.0+cu102. The error occurs on both CPU and GPU.

First, I initialize the encoder, a random input of shape (N, S, E), and a mask of shape (N * num_heads, S, S), as indicated in the nn.Transformer documentation.

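import torch
from torch import nn

d_model = 32
nhead = 4
dim_feedforward = 48
num_layers = 2
batch_size = 3
seq_length = 5

# Build the layer from the variables above rather than repeating the literals.
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
xencoder = nn.TransformerEncoder(layer, num_layers=num_layers)

x = torch.rand(batch_size, seq_length, d_model)
# Boolean attention mask; all False means no position is masked.
mask = torch.zeros(batch_size * nhead, seq_length, seq_length, dtype=torch.bool)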

forward() works fine when the model is in training mode or when torch.no_grad() is not used.

xencoder.train()
print(xencoder(x, mask=mask).shape)
# produces torch.Size([3, 5, 32])
xencoder.eval()
print(xencoder(x, mask=mask).shape)
# produces torch.Size([3, 5, 32])
xencoder.train()
with torch.no_grad():
    print(xencoder(x, mask=mask).shape)
# produces torch.Size([3, 5, 32])

However, when eval() and torch.no_grad() are used together, I receive a RuntimeError.

xencoder.eval()
with torch.no_grad():
    print(xencoder(x, mask=mask).shape)
# produces RuntimeError: Mask shape should match input shape

Below is the full error output.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [54], in <cell line: 2>()
      1 xencoder.eval()
      2 with torch.no_grad():
----> 3     print(xencoder(x, mask=mask).shape)

File ~/anaconda3/envs/ebml2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/ebml2/lib/python3.10/site-packages/torch/nn/modules/transformer.py:238, in TransformerEncoder.forward(self, src, mask, src_key_padding_mask)
    236         output = mod(output, src_mask=mask)
    237     else:
--> 238         output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
    240 if convert_to_nested:
    241     output = output.to_padded_tensor(0.)

File ~/anaconda3/envs/ebml2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/ebml2/lib/python3.10/site-packages/torch/nn/modules/transformer.py:437, in TransformerEncoderLayer.forward(self, src, src_mask, src_key_padding_mask)
    417     tensor_args = (
    418         src,
    419         self.self_attn.in_proj_weight,
   (...)
    430         self.linear2.bias,
    431     )
    432     if (not torch.overrides.has_torch_function(tensor_args) and
    433             # We have to use a list comprehension here because TorchScript
    434             # doesn't support generator expressions.
    435             all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]) and
    436             (not torch.is_grad_enabled() or all([not x.requires_grad for x in tensor_args]))):
--> 437         return torch._transformer_encoder_layer_fwd(
    438             src,
    439             self.self_attn.embed_dim,
    440             self.self_attn.num_heads,
    441             self.self_attn.in_proj_weight,
    442             self.self_attn.in_proj_bias,
    443             self.self_attn.out_proj.weight,
    444             self.self_attn.out_proj.bias,
    445             self.activation_relu_or_gelu == 2,
    446             False,  # norm_first, currently not supported
    447             self.norm1.eps,
    448             self.norm1.weight,
    449             self.norm1.bias,
    450             self.norm2.weight,
    451             self.norm2.bias,
    452             self.linear1.weight,
    453             self.linear1.bias,
    454             self.linear2.weight,
    455             self.linear2.bias,
    456             src_mask if src_mask is not None else src_key_padding_mask,
    457         )
    458 x = src
    459 if self.norm_first:

RuntimeError: Mask shape should match input shape

Looking at the source code, it seems like there is some fast compute path for MultiheadAttention when certain conditions are met, so I can’t tell if this is intended behavior or a bug. Should eval() and torch.no_grad() not be used simultaneously for Transformers? Or is there a different way to shape the mask tensor?

Any help would be appreciated. Thanks!
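For anyone stuck on an affected install, the combinations reported above suggest a possible stopgap: the eval() call without torch.no_grad() still produced the expected output, so the forward pass can be run with gradients enabled and the result detached afterwards. The sketch below assumes that observation holds for your setup. The helper name eval_forward_detached is made up for illustration and is not part of PyTorch.

```python
import torch
from torch import nn


def eval_forward_detached(encoder, src, mask):
    """Hypothetical helper: eval-mode forward without torch.no_grad().

    Gradients stay enabled during the call (the combination reported as
    working in this thread), and the output is detached afterwards.
    """
    encoder.eval()
    with torch.enable_grad():
        out = encoder(src, mask=mask)
    return out.detach()


# Same sizes as in the reproduction above.
layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, dim_feedforward=48, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

x = torch.rand(3, 5, 32)
mask = torch.zeros(3 * 4, 5, 5, dtype=torch.bool)

out = eval_forward_detached(encoder, x, mask)
print(out.shape, out.requires_grad)
```

The trade-off is that autograd still records the graph during the forward pass, so this uses more memory than a real torch.no_grad() block. It is a stopgap, not a replacement for a fixed release.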

Hi. Does anyone know anything about why this may be happening?

It seems this might have been an issue in PyTorch 1.12.0, as I cannot reproduce it in a source build from the current master branch. You might thus need to update to 1.12.1 (I haven’t checked this version) or to a nightly binary.
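To check whether an environment is on the release mentioned here, the version string can be compared against it. This is a minimal sketch in plain Python. The names parse_release, is_affected and AFFECTED_RELEASES are made up for illustration, and the assumption that only 1.12.0 is affected comes from this thread alone, not from any release notes.

```python
import re

# Assumption: only the release reported in this thread is affected.
AFFECTED_RELEASES = {(1, 12, 0)}


def parse_release(version: str) -> tuple:
    """Return (major, minor, patch) from a string such as '1.12.0+cu102'."""
    public = version.split("+", 1)[0]  # drop a local build tag like '+cu102'
    numbers = []
    for piece in public.split(".")[:3]:
        match = re.match(r"\d+", piece)  # leading digits only, so '0a0' -> 0
        numbers.append(int(match.group()) if match else 0)
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


def is_affected(version: str) -> bool:
    """True if the version string matches a release listed as affected."""
    return parse_release(version) in AFFECTED_RELEASES


print(is_affected("1.12.0+cu102"))  # True
print(is_affected("1.12.1"))        # False
```

In a real environment the argument would be torch.__version__, for example is_affected(torch.__version__).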

I just updated to 1.12.1 and it works now. Thank you!