Hi. I’m getting an odd error with nn.TransformerEncoder that only occurs when the model is in eval mode and torch.no_grad() is active. Below is some simple code to reproduce the error on PyTorch ‘1.12.0+cu102’. The error occurs on both CPU and GPU.
First, I initialize the encoder, some random input of shape (N, S, E), and a mask of shape (N * num_heads, S, S), as indicated in the nn.Transformer documentation.
import torch
import torch.nn as nn

d_model = 32
nhead = 4
dim_feedforward = 48
num_layers = 2
batch_size = 3
seq_length = 5
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True, dim_feedforward=dim_feedforward)
xencoder = nn.TransformerEncoder(layer, num_layers=num_layers)
x = torch.rand(batch_size, seq_length, d_model)
mask = torch.zeros(batch_size * nhead, seq_length, seq_length, dtype=torch.bool)
forward() works fine when the model is in training mode or when torch.no_grad() is not used.
xencoder.train()
print(xencoder(x, mask=mask).shape)
# produces torch.Size([3, 5, 32])
xencoder.eval()
print(xencoder(x, mask=mask).shape)
# produces torch.Size([3, 5, 32])
xencoder.train()
with torch.no_grad():
    print(xencoder(x, mask=mask).shape)
    # produces torch.Size([3, 5, 32])
However, when eval() and torch.no_grad() are used together, I receive a RuntimeError.
xencoder.eval()
with torch.no_grad():
    print(xencoder(x, mask=mask).shape)
    # produces RuntimeError: Mask shape should match input shape
Below is the full error output.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [54], in <cell line: 2>()
1 xencoder.eval()
2 with torch.no_grad():
----> 3 print(xencoder(x, mask=mask).shape)
File ~/anaconda3/envs/ebml2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/ebml2/lib/python3.10/site-packages/torch/nn/modules/transformer.py:238, in TransformerEncoder.forward(self, src, mask, src_key_padding_mask)
236 output = mod(output, src_mask=mask)
237 else:
--> 238 output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
240 if convert_to_nested:
241 output = output.to_padded_tensor(0.)
File ~/anaconda3/envs/ebml2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/ebml2/lib/python3.10/site-packages/torch/nn/modules/transformer.py:437, in TransformerEncoderLayer.forward(self, src, src_mask, src_key_padding_mask)
417 tensor_args = (
418 src,
419 self.self_attn.in_proj_weight,
(...)
430 self.linear2.bias,
431 )
432 if (not torch.overrides.has_torch_function(tensor_args) and
433 # We have to use a list comprehension here because TorchScript
434 # doesn't support generator expressions.
435 all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]) and
436 (not torch.is_grad_enabled() or all([not x.requires_grad for x in tensor_args]))):
--> 437 return torch._transformer_encoder_layer_fwd(
438 src,
439 self.self_attn.embed_dim,
440 self.self_attn.num_heads,
441 self.self_attn.in_proj_weight,
442 self.self_attn.in_proj_bias,
443 self.self_attn.out_proj.weight,
444 self.self_attn.out_proj.bias,
445 self.activation_relu_or_gelu == 2,
446 False, # norm_first, currently not supported
447 self.norm1.eps,
448 self.norm1.weight,
449 self.norm1.bias,
450 self.norm2.weight,
451 self.norm2.bias,
452 self.linear1.weight,
453 self.linear1.bias,
454 self.linear2.weight,
455 self.linear2.bias,
456 src_mask if src_mask is not None else src_key_padding_mask,
457 )
458 x = src
459 if self.norm_first:
RuntimeError: Mask shape should match input shape
Looking at the source code, it seems there is a fast compute path for MultiheadAttention that is only taken when certain conditions are met (grad disabled, supported device, etc.), so I can’t tell whether this is intended behavior or a bug. Should eval() and torch.no_grad() not be used simultaneously with Transformers? Or is there a different way the mask tensor should be shaped?
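For what it’s worth, here is the variant I was planning to try next as a sanity check, using the 2D (S, S) mask layout that the docs also allow instead of the 3D per-head mask (mask_2d is just a name I made up for this sketch). I haven’t confirmed whether it avoids the error on the fast path:
# Sketch of a possible workaround: pass the documented 2D (S, S) mask
# instead of the (N * num_heads, S, S) one. Not verified that this
# avoids the fast-path error.
mask_2d = torch.zeros(seq_length, seq_length, dtype=torch.bool)
xencoder.eval()
with torch.no_grad():
    print(xencoder(x, mask=mask_2d).shape)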
Any help would be appreciated. Thanks!