What does increasing the number of heads do in Multi-head Attention?

Can someone explain to me the point of the number of heads in MultiheadAttention?
What happens if I increase or decrease them? Would it change the number of learnable parameters?
What is the intuition behind increasing or decreasing the number of heads in MultiheadAttention?

I’ll try the intuition part…

You can think of all the heads as a panel of people, where each head is a different person with its own thoughts and view of the situation (the head's weights).
Each person gives their output, and then there is a leader who takes all the panel's outputs into account and gives the final verdict. That leader is the final feed-forward part of the multi-head attention: it concatenates all the outputs from the heads and feeds them to a linear layer to produce the final output.
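
Here's a rough sketch of that picture in code (illustrative only; the class name, shapes, and per-head layout are mine, and this is not how PyTorch's nn.MultiheadAttention is actually implemented):

```python
import torch
import torch.nn as nn

class NaiveMultiHead(nn.Module):
    """Sketch of the "panel of people" intuition, for illustration only."""
    def __init__(self, embed_dim, num_heads, head_dim):
        super().__init__()
        # Each "panel member" (head) gets its own q/k/v projections.
        self.heads = nn.ModuleList([
            nn.ModuleDict({
                "q": nn.Linear(embed_dim, head_dim),
                "k": nn.Linear(embed_dim, head_dim),
                "v": nn.Linear(embed_dim, head_dim),
            })
            for _ in range(num_heads)
        ])
        # The "leader": combines the concatenated head outputs.
        self.leader = nn.Linear(num_heads * head_dim, embed_dim)

    def forward(self, x):  # x: (batch, seq_len, embed_dim)
        opinions = []
        for head in self.heads:
            q, k, v = head["q"](x), head["k"](x), head["v"](x)
            scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
            opinions.append(torch.softmax(scores, dim=-1) @ v)  # each head's "opinion"
        return self.leader(torch.cat(opinions, dim=-1))         # the leader's "verdict"
```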

Adding more heads will add more parameters.

As a side note, more heads does not mean a better model; it's a hyperparameter, and the best value depends on the challenge.

Roy.

This part is not correct, I think: embed_dim is split into num_heads groups, so the parameter shapes are the same, but these groups are processed independently using reshaping (source).
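
A quick way to see this (a small check I wrote, not from the linked source):

```python
import torch.nn as nn

# With torch.nn.MultiheadAttention the total parameter count does not depend
# on num_heads, because embed_dim is split into num_heads groups.
for num_heads in (1, 2, 4, 8):
    mha = nn.MultiheadAttention(embed_dim=512, num_heads=num_heads)
    print(num_heads, sum(p.numel() for p in mha.parameters()))
# Prints the same total parameter count for every value of num_heads.
```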


@googlebot
Sorry, you are correct: the PyTorch implementation (following the "Attention Is All You Need" paper) will have the same parameter count regardless of the number of heads.

Just to note, there are other implementations of MultiHeadAttention where the parameter count does scale with the number of heads.
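
For example, something along the lines of the illustrative NaiveMultiHead sketch earlier in the thread, where each head keeps a fixed head_dim instead of splitting embed_dim:

```python
# Reusing the illustrative NaiveMultiHead class from the earlier post:
# with a fixed head_dim, every extra head adds its own q/k/v projections
# and widens the final "leader" layer, so the parameter count grows.
for num_heads in (1, 2, 4, 8):
    naive = NaiveMultiHead(embed_dim=512, num_heads=num_heads, head_dim=64)
    print(num_heads, sum(p.numel() for p in naive.parameters()))
```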

Roy

@RoySadaka @googlebot Thanks for the help.
Hmmm, so a larger or smaller number of heads doesn't necessarily mean better or worse generalization?

More heads might get you better generalization, and I suggest you try it out, but there's a chance it will not yield better results.

For example (true story):
I created a model that uses 4 heads, and adding more heads actually degraded the accuracy; I tested this both with the PyTorch implementation and with another implementation (one that adds more parameters for more heads).
Reducing the number of heads also hurt accuracy, so 4 is the magic number for my model and data.

The number of heads is a hyperparameter that you need to explore to find what fits best for the problem you are trying to solve.
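
If it helps, the exploration can be as simple as a loop over candidate values (train_and_evaluate here is a placeholder for your own training and validation routine, not a real function):

```python
# Hypothetical sweep over num_heads; train_and_evaluate is a placeholder
# standing in for your own training + validation loop returning a score.
results = {}
for num_heads in (1, 2, 4, 8):
    results[num_heads] = train_and_evaluate(num_heads=num_heads)
best_num_heads = max(results, key=results.get)
print(results, best_num_heads)
```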

Roy


@RoySadaka I find it very surprising that you found 4 to be the optimal number of attention heads in two very different implementations of multi-head attention: one that splits the tokens' embedding dimension (and thereby keeps the number of learnable parameters fixed with respect to the number of attention heads), and one that does not split the tokens' embedding dimension (and thereby increases the number of learnable parameters as the number of attention heads is increased). Are you sure? And if so, have you tried digging into why this may be happening?

Looking forward to hearing from anyone with an opinion on how this could happen (discounting chance), given that @RoySadaka's comment was two years ago.