MultiheadAttention module - How do I set embed_dim, kdim, and vdim?

I tried to set kdim to a value different from embed_dim, but I got an error stating that they should be the same. How should these values be set in the module, and what is the point of the arguments if they can't differ?

I looked at the source code, and now I am wondering whether I misunderstand the implementation or whether there is an error in it. In my opinion, the dimension of the q embedding should be the same as kdim, since q and k are the tensors used to calculate the correlations. Why is it set to embed_dim?

I don't get an error stating that kdim and vdim must equal embed_dim, as seen here:

import torch
import torch.nn as nn

embed_dim = 10
num_heads = 2

# default: kdim and vdim fall back to embed_dim
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

L, S, N, E = 2, 3, 4, embed_dim
query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)
attn_output, attn_output_weights = multihead_attn(query, key, value)


# with different kdim, vdim
kdim = 5
vdim = 6
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim)

L, S, N, E = 2, 3, 4, embed_dim
query = torch.randn(L, N, E)
key = torch.randn(S, N, kdim)
value = torch.randn(S, N, vdim)
attn_output, attn_output_weights = multihead_attn(query, key, value)

Could you post a code snippet to reproduce this issue, as I might misunderstand it?

Thanks for your answer. I think the misunderstanding is on my side: I expected the module (nn.MultiheadAttention) to embed the q, k, v values internally into spaces of size kdim and vdim, ideally with qdim = kdim. As implemented, kdim and vdim only describe the feature sizes of the incoming key and value tensors. The module projects query, key, and value to embed_dim, so unfortunately the attention operation itself cannot use different embedding dimensions.
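To make that concrete, here is a small sketch that inspects the projection weights of the module from the snippet above. It assumes that nn.MultiheadAttention exposes q_proj_weight, k_proj_weight, and v_proj_weight when kdim or vdim differ from embed_dim, which is what I see in the source; the variable names are only for illustration.

```python
import torch
import torch.nn as nn

embed_dim, num_heads, kdim, vdim = 10, 2, 5, 6
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim)

# With kdim/vdim different from embed_dim, the module keeps separate
# projection weights; each one maps its input to embed_dim.
q_shape = tuple(mha.q_proj_weight.shape)  # (embed_dim, embed_dim)
k_shape = tuple(mha.k_proj_weight.shape)  # (embed_dim, kdim)
v_shape = tuple(mha.v_proj_weight.shape)  # (embed_dim, vdim)

L, S, N = 2, 3, 4
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, kdim)
value = torch.randn(S, N, vdim)
attn_output, attn_weights = mha(query, key, value)

out_shape = tuple(attn_output.shape)       # (L, N, embed_dim)
weights_shape = tuple(attn_weights.shape)  # (N, L, S), averaged over heads
```

So kdim and vdim only change the input side of the key and value projections; the attention itself always runs in embed_dim (split across the heads).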

In my opinion it would be much more useful to be able to perform attention with separate dimensions for kdim and vdim. Is there a function or module that only performs the attention, without an internal embedding stage?
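For what it's worth, plain scaled dot-product attention without any projection is short enough to write by hand. The following is only a minimal sketch, not the module's implementation: the helper name scaled_dot_product is made up, it handles a single head with no masking or dropout, and it uses a batch-first layout (N, L, d) rather than the (L, N, E) layout from the snippet above. Query and key must share a feature size (d_k here), while value may use a different one (d_v).

```python
import math
import torch

def scaled_dot_product(q, k, v):
    """Hypothetical helper: q (N, L, d_k), k (N, S, d_k), v (N, S, d_v).

    Returns the attended values (N, L, d_v) and the weights (N, L, S).
    """
    # Similarity of every query position with every key position.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # Normalize over the key positions.
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

N, L, S, d_k, d_v = 4, 2, 3, 5, 6
q = torch.randn(N, L, d_k)
k = torch.randn(N, S, d_k)
v = torch.randn(N, S, d_v)
out, weights = scaled_dot_product(q, k, v)
```

Recent PyTorch releases (2.0 and later) also ship torch.nn.functional.scaled_dot_product_attention, which should cover the same computation if your version has it.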