Questions about `torch.nn.MultiheadAttention`

I don’t understand how the nn.MultiheadAttention module works.

  1. What is the meaning of `kdim` and `vdim` in `__init__`? Are they related to the key and value parameters in `forward()`?

  2. Why must `embed_dim` be divisible by `num_heads`? What is the meaning of the `head_dim` derived from such a formulation?

    I thought multi-head self-attention works this way:

  • the embed_dim inputs are projected into query_dim, key_dim and value_dim respectively
  • after computing attention weights between the query and the key, the values are gathered according to the attention weights of each query; this forms the output of the current attention head.
  • finally, the outputs of all heads are concatenated into one matrix of size head_num * value_dim and projected back to embed_dim. Thus, I don’t see where the module should divide embed_dim by num_heads.
  3. What is the difference between torchtext.nn.MultiheadAttention and torch.nn.MultiheadAttention?
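For context, here is a minimal usage sketch of `torch.nn.MultiheadAttention` in the self-attention setting (the dimensions below are arbitrary choices of mine, not from the questions above):

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
# kdim and vdim default to embed_dim when not given
mha = nn.MultiheadAttention(embed_dim, num_heads)

# Inputs are (seq_len, batch, embed_dim) unless batch_first=True
seq_len, batch = 5, 3
x = torch.randn(seq_len, batch, embed_dim)

# Self-attention: query, key and value are all the same tensor
out, weights = mha(x, x, x)
print(out.shape)      # (5, 3, 8)  -- same shape as the query
print(weights.shape)  # (3, 5, 5)  -- attention weights, averaged over heads by default
```

Note that the output always has the query's `embed_dim` as its last dimension, which is relevant to the questions about `kdim`/`vdim` below.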

UP! I am very confused by this as well! Can anyone clarify?

Also: how can kdim be different from embed_dim (which is hard-coded to be the query dimension, which also makes no sense / is not very general)? In order to take the dot product between queries and keys, they have to have the same length! The only one of q, k, v that should be able to have a different length is v.

And vdim should also determine the size of the output, but then out_proj (the matrix that projects the attention output to the final output of the layer) is set up with shape (embed_dim, embed_dim), when it should be (vdim, embed_dim), or better (vdim, out_dim) to be more general.

It seems that this layer can only work when embed_dim = kdim = vdim, which is the case anyway when you do not specify kdim and vdim, so what’s the point?

I have not yet tried playing around with it to see if it really doesn’t work when kdim or vdim differ from embed_dim, but just from looking at the code and the docs, this all seems a bit messy and not very well documented!
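A quick experiment (my own sketch, with arbitrary dimensions) suggests the layer does accept kdim/vdim different from embed_dim: the key and value are first projected from kdim and vdim up to embed_dim, attention happens entirely in embed_dim space, and that is why out_proj can stay (embed_dim, embed_dim):

```python
import torch
import torch.nn as nn

embed_dim, kdim, vdim, num_heads = 8, 6, 4, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim)

# When kdim/vdim differ from embed_dim, separate projection matrices are used:
print(mha.k_proj_weight.shape)  # (8, 6) -- projects keys   from kdim to embed_dim
print(mha.v_proj_weight.shape)  # (8, 4) -- projects values from vdim to embed_dim

q = torch.randn(5, 3, embed_dim)  # (tgt_len, batch, embed_dim)
k = torch.randn(7, 3, kdim)       # (src_len, batch, kdim)
v = torch.randn(7, 3, vdim)       # (src_len, batch, vdim)
out, _ = mha(q, k, v)
print(out.shape)  # (5, 3, 8) -- always embed_dim, regardless of vdim
```

So the dot product between queries and keys is taken after the key projection, when both live in embed_dim space; vdim never reaches the output because values are projected up to embed_dim as well.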

@namespace_pt @steffenN

Question 1: kdim and vdim are the dimensions of your key and value inputs (they correspond to the key and value arguments in forward()).
Question 2: It must be divisible because embed_dim is split across the heads. If embed_dim = 300 and num_heads = 2, the first head works on one 150-dim part of the embedding and the second head works on the other 150; the results of the two heads are then concatenated.
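The splitting described above can be sketched like this (my own illustration of the reshape, with the batch dimension omitted for brevity):

```python
import torch

embed_dim, num_heads = 300, 2
head_dim = embed_dim // num_heads  # 150; this is why embed_dim % num_heads must be 0
assert head_dim * num_heads == embed_dim

# After the q/k/v projections, each 300-dim vector is viewed as
# (num_heads, head_dim): each head attends over its own 150-dim slice.
x = torch.randn(5, embed_dim)                 # (seq_len, embed_dim)
per_head = x.reshape(5, num_heads, head_dim)
print(per_head.shape)                          # (5, 2, 150)

# After attention, the head outputs are concatenated back to embed_dim
# and passed through out_proj, which is (embed_dim, embed_dim).
concat = per_head.reshape(5, embed_dim)
print(concat.shape)                            # (5, 300)
```

So no extra parameters are introduced per head; the single projection is simply partitioned, which is why head_dim = embed_dim // num_heads must be an integer.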

Please see the documentation for torch.nn.MultiheadAttention. It should clear up your doubts.