Questions about `torch.nn.MultiheadAttention`

I don’t understand how the nn.MultiheadAttention module works.

  1. What is the meaning of kdim and vdim respectively in __init__? Are they related to the key and value parameters in forward()?

  2. Why must embed_dim be divisible by num_heads? What is the meaning of the head_dim derived from this formulation?

    I thought multi-head self-attention works this way:

  • embedding_dim is projected into query_dim, key_dim, and value_dim respectively
  • after computing attention weights between query and key, the values are gathered according to the attention weights of each query, which forms exactly the output of the current attention head.
  • finally, the outputs of all heads are concatenated into one matrix of num_heads * value_dim columns and projected back into embedding_dim. Thus, I don’t see where the module should divide embedding_dim by num_heads.
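The scheme described in the bullets above can be sketched like this (a minimal illustration of the poster’s mental model, not PyTorch’s actual implementation; all sizes are assumptions):

```python
import torch

# Toy multi-head attention where each head has its own full-width
# projection of the input, as the bullets above describe.
embed_dim, head_dim, num_heads, seq_len = 8, 4, 2, 5
x = torch.randn(seq_len, embed_dim)

outputs = []
for _ in range(num_heads):
    # Each head projects the FULL embedding into its own Q, K, V spaces.
    W_q = torch.randn(embed_dim, head_dim)
    W_k = torch.randn(embed_dim, head_dim)
    W_v = torch.randn(embed_dim, head_dim)
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    attn = torch.softmax(q @ k.T / head_dim ** 0.5, dim=-1)
    outputs.append(attn @ v)            # (seq_len, head_dim)

# Concatenate all heads and project back to embed_dim.
concat = torch.cat(outputs, dim=-1)     # (seq_len, num_heads * head_dim)
W_o = torch.randn(num_heads * head_dim, embed_dim)
out = concat @ W_o
print(out.shape)  # torch.Size([5, 8])
```

Note that in this sketch head_dim is a free choice, which is exactly why the divisibility requirement is surprising at first glance.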
  3. What is the difference between torchtext.nn.MultiheadAttention and torch.nn.MultiheadAttention?

Bumping this: I am very confused by this as well. Can anyone clarify?

Also: how can kdim be different from embed_dim (which is hardcoded to be the query dimension, which also makes no sense / is not very general)? To take the dot product between queries and keys, they have to have the same length! The only one among q, k, v that should be able to have a different length is v.

And vdim should also determine the size of the output, but out_proj (the matrix that projects that output to the final output of the layer) is set up with shape (embed_dim, embed_dim), when it should be (vdim, embed_dim), or better (vdim, out_dim) to be more general.

It seems that this layer can only work when embed_dim = kdim = vdim, which is the case anyway when you do not specify kdim and vdim, so what’s the point?

I have not yet played around with it to see whether it really breaks when kdim or vdim differ from embed_dim, but just from looking at the code and the docs, this all seems a bit messy and not very well documented.
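For what it’s worth, a quick experiment suggests the layer does accept kdim and vdim different from embed_dim: keys and values are projected to embed_dim internally before the dot products, so the shapes line up (a sketch with assumed sizes; shapes follow the default batch_first=False convention):

```python
import torch
import torch.nn as nn

# Quick check: kdim / vdim different from embed_dim.
# The layer projects keys and values to embed_dim internally,
# so the query-key dot products still line up.
embed_dim, num_heads, kdim, vdim = 8, 2, 4, 6
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim)

L, S, N = 5, 7, 3                     # target len, source len, batch
query = torch.randn(L, N, embed_dim)  # queries use embed_dim
key = torch.randn(S, N, kdim)         # keys use kdim
value = torch.randn(S, N, vdim)       # values use vdim

out, weights = mha(query, key, value)
print(out.shape)      # torch.Size([5, 3, 8])  -> still embed_dim
print(weights.shape)  # torch.Size([3, 5, 7])
```

So the output size is always embed_dim regardless of vdim, which is consistent with out_proj having shape (embed_dim, embed_dim): the value projection has already mapped vdim to embed_dim.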

@namespace_pt @steffenN

Question 1: kdim and vdim are the dimensions of your inputs (the same as key and value in self-attention).
Question 2: It must be divisible because embed_dim is divided across the different heads. So if your embed_dim = 300 and you have num_heads = 2, the first head works on 150 dimensions of the embedding and the second head works on the other 150; the results of the two heads are later concatenated.
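The split can be seen as a reshape of the projected tensor (a sketch with the sizes from the example above; variable names are assumptions):

```python
import torch

# embed_dim = 300 split across num_heads = 2 heads of 150 each,
# done as a reshape after the input projection.
embed_dim, num_heads = 300, 2
head_dim = embed_dim // num_heads   # 150; this is why the division
                                    # must be exact
seq_len = 10
projected = torch.randn(seq_len, embed_dim)   # e.g. the result of x @ W_q

# (seq_len, embed_dim) -> (num_heads, seq_len, head_dim)
per_head = projected.view(seq_len, num_heads, head_dim).transpose(0, 1)
print(per_head.shape)  # torch.Size([2, 10, 150])
```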

Please have a look at the documentation of torch.nn.MultiheadAttention. It can clear up all your doubts.

Hi @AbdulsalamBande,

Thank you for your answer. I am also confused by point 2: I understand your answer, but it doesn’t make sense to me that it works like that.

It is very weird to me to prevent a head from seeing part of the features of the input. Going back to your example, you deprive each head of half of the features of the input.

Extending your example, let’s say head 1 specializes in semantics and head 2 specializes in logical connectors in the sentence. If you give only half of the features to head 1, it is difficult for that head to do its job, because the split it receives may be missing important features that carry a lot of semantic information.
The same goes for head 2: its split may be missing important features regarding logic.

I don’t understand how this architecture can be efficient; I am lacking the intuition here. What would make sense to me is for the split to happen AFTER the projection of Q, K and V. But then we would not need embed_dim to be divisible by num_heads.


Your transformer model has a certain dimensionality, say 512. Each “head” gets part of that vector to hold its representation. So if you have a 512-dimensional vector representation and 8 heads, each head gets 512/8 = 64 numbers to store its representation. The dimensionality is not about which features the head gets access to; it’s about the amount of numbers the head gets to use to store its representation.
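In other words, the split happens after the projection, so every head still sees every input feature. Slicing the columns of the full projection matrix is the same as giving each head its own (embed_dim x head_dim) projection of the entire input (a sketch with assumed sizes):

```python
import torch

# Each head's slice comes AFTER a full-width projection: slicing the
# columns of W_q is the same as giving each head its own
# (embed_dim x head_dim) projection of the *entire* input.
embed_dim, num_heads, seq_len = 8, 2, 5
head_dim = embed_dim // num_heads
x = torch.randn(seq_len, embed_dim)
W_q = torch.randn(embed_dim, embed_dim)

q = x @ W_q                                      # full projection first
q_heads = q.view(seq_len, num_heads, head_dim)   # then split into heads

# Head 0 computed directly from the full input with its own weight slice:
q_head0 = x @ W_q[:, :head_dim]
assert torch.allclose(q_heads[:, 0, :], q_head0, atol=1e-5)
print("each head sees all input features")
```

So no head is deprived of any input features; the division only decides how many output numbers each head contributes.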