What does increasing the number of heads do in Multi-head Attention?

Can someone explain the point of the number of heads in MultiheadAttention?
What happens if I increase or decrease it? Would that change the number of learnable parameters?
What is the intuition behind increasing or decreasing the number of heads in MultiheadAttention?

I’ll try the intuition part…

You can think of the heads as a panel of people: each head is a different person with their own thoughts and view of the situation (the head’s weights).
Each person gives their output, and then a leader takes all of the panel’s outputs into account and gives the final verdict. That leader is the final projection step of the multi-head block: it concatenates the outputs of all the heads and feeds them to a linear layer to produce the final output.
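To make the panel picture concrete, here is a minimal pure-Python sketch of a multi-head forward pass. It is an illustration under simplifying assumptions, not the PyTorch code: biases, masking, dropout, and batching are left out, and the helper names (`matmul`, `softmax`, `attention`, `multi_head`) are made up for this example.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Scaled dot-product attention for a single head."""
    scale = math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]  # transpose of k
    scores = matmul(q, k_t)
    weights = [softmax([s / scale for s in row]) for row in scores]
    return matmul(weights, v)


def multi_head(x, w_q, w_k, w_v, w_o, num_heads):
    """Split the projected features into heads, attend per head, then merge."""
    embed_dim = len(x[0])
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    head_dim = embed_dim // num_heads
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)

    head_outputs = []
    for h in range(num_heads):
        lo, hi = h * head_dim, (h + 1) * head_dim
        # Each "panel member" only sees its own slice of the features.
        q_h = [row[lo:hi] for row in q]
        k_h = [row[lo:hi] for row in k]
        v_h = [row[lo:hi] for row in v]
        head_outputs.append(attention(q_h, k_h, v_h))

    # The "leader": concatenate the head outputs, then apply one linear layer.
    concat = [sum((head[t] for head in head_outputs), []) for t in range(len(x))]
    return matmul(concat, w_o)


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
seq_len, embed_dim = 3, 8
x = rand_matrix(seq_len, embed_dim)
w_q, w_k, w_v, w_o = (rand_matrix(embed_dim, embed_dim) for _ in range(4))

for num_heads in (1, 2, 4):
    out = multi_head(x, w_q, w_k, w_v, w_o, num_heads)
    print(num_heads, "heads -> output shape", (len(out), len(out[0])))
```

In this sketch the weight matrices have the same shapes for every choice of `num_heads`; only the way the projected features are sliced changes.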

Adding more heads will add more parameters.

As a side note, more heads does not mean a better model. The number of heads is a hyperparameter, and the best value depends on the task.

Roy.

I think this part is not correct: embed_dim is split into num_heads groups, each of size embed_dim / num_heads. So the parameter shapes stay the same, but the groups are processed independently using reshaping (source).
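A quick way to check this is to count the parameters by hand. The sketch below assumes the default layout of torch.nn.MultiheadAttention (one packed input projection for Q, K, and V, plus an output projection, both with bias, and key/value dimensions equal to embed_dim). The function name `mha_param_count` is made up for this example.

```python
def mha_param_count(embed_dim, num_heads):
    """Parameter count for a packed-projection multi-head attention layer."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    in_proj_weight = 3 * embed_dim * embed_dim  # Q, K, V projections packed together
    in_proj_bias = 3 * embed_dim
    out_proj_weight = embed_dim * embed_dim     # final linear layer after concatenation
    out_proj_bias = embed_dim
    # num_heads never appears in the shapes; it only sets head_dim.
    return in_proj_weight + in_proj_bias + out_proj_weight + out_proj_bias


for num_heads in (1, 2, 4, 8):
    head_dim = 512 // num_heads
    print(num_heads, head_dim, mha_param_count(512, num_heads))
```

Every row prints the same total (1050624 for embed_dim = 512); only head_dim changes.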

@googlebot
Sorry, you are correct: the PyTorch implementation (following the “Attention Is All You Need” paper) has the same parameter count regardless of the number of heads.

Just to note, there are other implementations of multi-head attention where the parameter count scales with the number of heads.
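As a rough sketch of such a variant, assume each head gets its own Q, K, and V projections from embed_dim down to a fixed head_dim (with bias), and the output layer maps the concatenated heads back to embed_dim. This layout and the name `per_head_param_count` are assumptions for illustration, not a specific library's implementation.

```python
def per_head_param_count(embed_dim, num_heads, head_dim):
    """Parameter count when head_dim is fixed and each head has its own projections."""
    qkv_per_head = 3 * (embed_dim * head_dim + head_dim)  # Q, K, V weights and biases
    out_proj = num_heads * head_dim * embed_dim + embed_dim
    return num_heads * qkv_per_head + out_proj


for num_heads in (4, 8, 16):
    print(num_heads, per_head_param_count(512, num_heads, head_dim=64))
```

Here the count grows roughly linearly with the number of heads, because head_dim no longer shrinks as heads are added.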

Roy

@RoySadaka @googlebot Thanks for the help.
Hmm, so a large or small number of heads does not by itself mean better or worse generalization?

More heads might give you better generalization, and I suggest you try it out, but there’s a chance it will not yield better results.

For example (true story):
I built a model that uses 4 heads, and adding more heads actually degraded the accuracy. I tested this both with the PyTorch implementation and with another implementation (one that adds more parameters for more heads).
Reducing the number of heads also hurt accuracy, so 4 is the magic number for my model and data.

The number of heads is a hyperparameter that you need to explore to find what best fits the problem you are trying to solve.
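When setting up such a search with the PyTorch implementation, keep in mind that embed_dim must be divisible by num_heads, so the candidates are the divisors of embed_dim. A small helper to list them (the name `valid_head_counts` is made up for this example):

```python
def valid_head_counts(embed_dim):
    """Return every num_heads value that divides embed_dim evenly."""
    return [h for h in range(1, embed_dim + 1) if embed_dim % h == 0]


print(valid_head_counts(64))  # [1, 2, 4, 8, 16, 32, 64]
```

You would then train and validate the model once per candidate and keep the value that scores best on held-out data.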

Roy
