Is nn.MultiheadAttention attn_mask working differently in PyTorch 2.0?

    return torch._native_multi_head_attention(
RuntimeError: Mask shape should match input. mask: [2048, 172, 172] input: [128, 16, 172, 172]

The error reports a dimension problem, but 2048 = batch_size * num_heads = 128 * 16 (L = 172, S = 172, self-attention in my case).

According to the MultiheadAttention — PyTorch 2.0 documentation, the mask is supposed to have the right shape.
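Roughly, my setup looks like the sketch below (a minimal, hypothetical reproduction, not my actual code: the embedding size is a placeholder, and the other sizes are just the ones from the error message):

    import torch
    import torch.nn as nn

    # Sizes chosen to mirror the error message above; embed_dim is a placeholder.
    batch_size, num_heads, seq_len, embed_dim = 128, 16, 172, 256

    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
    x = torch.randn(batch_size, seq_len, embed_dim)

    # 3-D mask of shape (batch_size * num_heads, L, S), as described in the docs.
    attn_mask = torch.zeros(batch_size * num_heads, seq_len, seq_len, dtype=torch.bool)

    with torch.no_grad():
        # Depending on version/build, the 2.0 fast path
        # (torch._native_multi_head_attention) is taken here and raises:
        #   RuntimeError: Mask shape should match input.
        #   mask: [2048, 172, 172] input: [128, 16, 172, 172]
        out, _ = mha(x, x, x, attn_mask=attn_mask)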

Yes, the number of elements is correct, but you need to match the shape exactly. You can do that by calling mask.flatten(0, 1).
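For example (just a quick illustration, assuming the mask was built per head as a 4-D tensor):

    import torch

    batch_size, num_heads, L, S = 128, 16, 172, 172

    # Mask built per batch element and per head: (batch_size, num_heads, L, S)
    mask_4d = torch.zeros(batch_size, num_heads, L, S, dtype=torch.bool)

    # Flatten the first two dims into the (batch_size * num_heads, L, S)
    # layout that attn_mask expects.
    mask_3d = mask_4d.flatten(0, 1)
    print(mask_3d.shape)  # torch.Size([2048, 172, 172])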

No, in my error the mask does have the right dimensions; what is off is the “input” shape, which is derived inside the MultiheadAttention module, so I can’t modify it there.

Furthermore, if I can make a suggestion: I find the (batch_size * num_heads, …) layout less intuitive to read. Since 95% of the time you apply the same mask to all heads, it would make more sense the other way around, because you could simply broadcast over dimension 1, the one that corresponds to the heads, e.g. with a mask of shape (batch_size, 1, source, target); see the sketch below.
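For comparison, here is a rough sketch of what you have to do today when the mask is shared across heads (the sizes are just the ones from my error message):

    import torch

    batch_size, num_heads, L, S = 128, 16, 172, 172

    # One mask per batch element, shared by all heads: (batch_size, 1, L, S)
    mask = torch.zeros(batch_size, 1, L, S, dtype=torch.bool)

    # Today it has to be expanded over the head dimension and flattened
    # to (batch_size * num_heads, L, S) before being passed as attn_mask.
    mask_per_head = mask.expand(batch_size, num_heads, L, S).reshape(-1, L, S)
    print(mask_per_head.shape)  # torch.Size([2048, 172, 172])

With broadcasting over the head dimension, the (batch_size, 1, L, S) mask could be passed directly.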

Hmm, do you have a short self-contained snippet for reproduction?

Hi, I’ve just encountered the exact same issue. Have you found a way to resolve it?