nn.TransformerEncoderLayer 3D Mask Doesn't Match the Broadcast Shape

Hello. I am trying to use a 3D mask with an nn.TransformerEncoderLayer, but I cannot seem to get the dimensions right. The nn.Transformer documentation gives the source shape as (source sequence length, batch size, feature number), which I use below. The nn.MultiheadAttention documentation says a 3D mask should have shape "(N*num_heads, L, S), where N is the batch size, L is the target sequence length, S is the source sequence length," which I also use below.

Reproducible Example Code:

import torch
from torch import nn

num_heads = 8
bs = 16
num_features = 768
src_and_target_size = 21

encoder_layer = nn.TransformerEncoderLayer(num_features, num_heads, dim_feedforward=2048, dropout=0.1)

src = torch.rand(src_and_target_size, bs, num_features)  # (S, N, E)

mask = torch.ones(bs*num_heads, src_and_target_size, src_and_target_size)  # (N*num_heads, L, S)

print(src.shape)
print(mask.shape)

encoder_layer(src, mask)

However, I receive the error message below. Can you please help me find the error in my mask dimensions?

Error Message:

Traceback (most recent call last):
  File "test.py", line 19, in <module>
    encoder_layer(src, mask)
  File "/home/user/anaconda3/envs/test_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user/anaconda3/envs/test_env/lib/python3.7/site-packages/torch/nn/modules/transformer.py", line 283, in forward
    key_padding_mask=src_key_padding_mask)[0]
  File "/home/user/anaconda3/envs/test_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user/anaconda3/envs/test_env/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 819, in forward
    attn_mask=attn_mask)
  File "/home/user/anaconda3/envs/test_env/lib/python3.7/site-packages/torch/nn/functional.py", line 3362, in multi_head_attention_forward
    attn_output_weights += attn_mask
RuntimeError: output with shape [128, 21, 21] doesn't match the broadcast shape [1, 128, 21, 21]

Thanks.

The PyTorch code that produces the error:

torch/nn/functional.py, line 3362, in multi_head_attention_forward:

if attn_mask is not None:
    attn_mask = attn_mask.unsqueeze(0)
    attn_output_weights += attn_mask

Commenting out the attn_mask.unsqueeze(0) line makes the code run without errors. Is this a bug in PyTorch? Why add a dimension to the beginning of the attention mask?
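
For what it's worth, the error itself is plain in-place broadcasting: an in-place += cannot grow a 3D tensor into the 4D broadcast result. A minimal sketch, outside of any transformer code, that reproduces the same failure:

import torch

attn_output_weights = torch.zeros(128, 21, 21)  # (N*num_heads, L, S) = (16*8, 21, 21)
attn_mask = torch.ones(128, 21, 21)             # 3D mask with the documented shape

attn_mask = attn_mask.unsqueeze(0)              # now (1, 128, 21, 21)
attn_output_weights += attn_mask                # RuntimeError: output with shape [128, 21, 21]
                                                # doesn't match the broadcast shape [1, 128, 21, 21]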

Solution: Upgrade to PyTorch 1.5. Newer versions only unsqueeze 2D masks, so a 3D mask of shape (N*num_heads, L, S) is accepted as-is.
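
If upgrading is not an option, one workaround on 1.4 (my own sketch, not an official fix) is to pass a 2D additive mask of shape (L, S); unsqueeze(0) then yields (1, L, S), which broadcasts against the (N*num_heads, L, S) attention weights without error:

import torch
from torch import nn

num_heads = 8
bs = 16
num_features = 768
src_and_target_size = 21

encoder_layer = nn.TransformerEncoderLayer(num_features, num_heads, dim_feedforward=2048, dropout=0.1)
src = torch.rand(src_and_target_size, bs, num_features)

# 2D additive mask (L, S); zeros are a no-op, use float('-inf') to block positions
mask_2d = torch.zeros(src_and_target_size, src_and_target_size)

out = encoder_layer(src, mask_2d)
print(out.shape)  # torch.Size([21, 16, 768])

The trade-off is that a 2D mask is shared across all heads and batch elements, so per-sample masking still needs 1.5+.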