No, in my error the mask does have the right dimensions; the mismatch is on the "input" dimension, which is derived inside the MultiheadAttention module, so I can't modify it there.
Furthermore, if I can make a suggestion: the (batch_size * num_heads, …) layout is less intuitive to read. Since 95% of the time you apply the same mask to all heads, it would make more sense the other way around, because you could simply broadcast over dimension 1, which corresponds to the heads:
e.g. with a mask of shape (batch_size, 1, source, target).
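To make the comparison concrete, here is a minimal sketch of the current workaround versus what the broadcastable shape would look like. The sizes, the repeat_interleave replication, and the proposed_mask name are just illustrative assumptions, not the module's actual API for this proposal:

```python
import torch
import torch.nn as nn

# Illustrative sizes only
batch_size, num_heads, embed_dim = 2, 4, 16
tgt_len, src_len = 5, 7

mha = nn.MultiheadAttention(embed_dim, num_heads)

# One boolean mask per batch element (True = position may NOT be attended).
per_batch_mask = torch.zeros(batch_size, tgt_len, src_len, dtype=torch.bool)
per_batch_mask[1, :, -1] = True  # e.g. hide the last source position of sample 1

# Today: the same mask has to be replicated for every head so that the first
# dimension matches batch_size * num_heads (batch index changing slowest).
attn_mask = per_batch_mask.repeat_interleave(num_heads, dim=0)  # (batch*heads, tgt, src)

query = torch.randn(tgt_len, batch_size, embed_dim)
key = value = torch.randn(src_len, batch_size, embed_dim)
out, _ = mha(query, key, value, attn_mask=attn_mask)

# What the suggestion would allow instead: keep heads as a size-1 dimension,
# (batch_size, 1, tgt_len, src_len), and let it broadcast over all heads.
# Not accepted by MultiheadAttention today; shown only for comparison.
# proposed_mask = per_batch_mask.unsqueeze(1)
```

With the broadcastable shape, the common single-mask-per-sample case would need no replication at all, and the per-head case would still work by passing a full (batch_size, num_heads, source, target) mask.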