As you can see in the following snippet of PyTorch code,
when feeding a 3D attn_mask, the expected shape is (bsz * num_heads, tgt_len, src_len) and not simply (bsz, tgt_len, src_len).
Can I ask why?
It seems to me that 2D and 3D masks are treated differently: since a 2D mask gets a new leading dimension, meaning the same mask is used for every element of the batch, I would have expected a 3D mask to have a leading dimension equal to batch_size, regardless of the number of heads.
It is possible that I don’t see the real meaning behind this behavior…
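
For reference, here is a minimal sketch of the workaround this shape requirement seems to imply: building one mask per batch element and repeating it for every head. It assumes the heads of one batch element sit next to each other along the leading dimension (index b * num_heads + h), which is my reading of the reshape in the PyTorch source. The sizes and the name per_batch_mask are made up for illustration.

```python
import torch

# Illustrative sizes only (assumptions, not taken from any real model).
bsz, num_heads, tgt_len, src_len, embed_dim = 2, 4, 3, 5, 8

# One boolean mask per batch element; True means "do not attend".
per_batch_mask = torch.zeros(bsz, tgt_len, src_len, dtype=torch.bool)
per_batch_mask[1, :, -1] = True  # hide the last source position for sample 1 only

# Repeat each batch entry num_heads times along dim 0, giving
# (bsz * num_heads, tgt_len, src_len). repeat_interleave keeps the copies
# of one batch element adjacent, matching the assumed b * num_heads + h layout.
attn_mask = per_batch_mask.repeat_interleave(num_heads, dim=0)

mha = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
query = torch.randn(tgt_len, bsz, embed_dim)  # (tgt_len, bsz, embed_dim)
key = value = torch.randn(src_len, bsz, embed_dim)

# weights are averaged over heads: (bsz, tgt_len, src_len)
out, weights = mha(query, key, value, attn_mask=attn_mask)
```

If that layout assumption holds, the attention weights for sample 1 on the last source position come out as zero while sample 0 is unaffected, so the 3D form at least allows a different mask per batch element (and, in principle, per head).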