[Possible Bug] Forcing 3D attn_mask to have wrong shape in MultiHeadAttention

Hi,
As you can see in the following snippet of PyTorch code:

MultiHeadAttention (pytorch/torch/nn/functional.py)

when feeding a 3D attn_mask, the expected shape is (bsz * num_heads, tgt_len, src_len) and not simply (bsz, tgt_len, src_len).
Can I ask why?

It seems to me that 2D and 3D masks are treated differently: a 2D mask gets a new leading dimension, so the same mask is used for every element of the batch. By analogy, I would have expected a 3D mask to have a leading dimension equal to batch_size, regardless of the number of heads.
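To make the shape question concrete, here is a minimal sketch of how a per-batch mask of shape (bsz, tgt_len, src_len) could be expanded to the (bsz * num_heads, tgt_len, src_len) shape that is expected. The helper name expand_mask_per_head is hypothetical (it is not part of PyTorch), and the sketch assumes that the heads of one batch element occupy consecutive positions along the leading dimension, which is why it uses repeat_interleave rather than repeat. The comparison against a per-sample run with a 2D mask is there to check that assumption.

```python
import torch
import torch.nn as nn


def expand_mask_per_head(mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Hypothetical helper: (bsz, tgt_len, src_len) -> (bsz * num_heads, tgt_len, src_len).

    Assumes the leading index is laid out as b * num_heads + h, so each
    batch element's mask is repeated num_heads times in a row.
    """
    return mask.repeat_interleave(num_heads, dim=0)


torch.manual_seed(0)
bsz, num_heads, tgt_len, src_len, embed_dim = 2, 4, 3, 5, 8
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()

query = torch.randn(bsz, tgt_len, embed_dim)
memory = torch.randn(bsz, src_len, embed_dim)

# Boolean mask, True means "not allowed to attend"; a different mask per batch element.
per_batch_mask = torch.rand(bsz, tgt_len, src_len) > 0.7
per_batch_mask[:, :, 0] = False  # keep at least one visible position in every row

# Batched call with the expanded 3D mask.
expanded = expand_mask_per_head(per_batch_mask, num_heads)
out_batched, _ = mha(query, memory, memory, attn_mask=expanded)

# Reference: run each batch element on its own with its 2D mask.
out_reference = torch.cat([
    mha(query[b:b + 1], memory[b:b + 1], memory[b:b + 1],
        attn_mask=per_batch_mask[b])[0]
    for b in range(bsz)
])

print(expanded.shape)
print(torch.allclose(out_batched, out_reference, atol=1e-5))
```

If the layout assumption holds, the batched output matches the per-sample reference, so the 3D form can be read as "one mask per (batch element, head) pair" rather than "one mask per batch element".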

It is possible that I don’t see the real meaning behind this behavior…
Thanks