Understanding nn.MultiheadAttention

I am trying to understand the nn.MultiheadAttention implementation and tried the following:

import math
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0, bias=False, add_bias_kv=False, add_zero_attn=False)

seq_len = 2
x = torch.rand(seq_len, embed_dim)  # unbatched input of shape (seq_len, embed_dim)

# Self-attention: Reference calculations
attn_output, attn_output_weights = mha(x, x, x)

# My manual calculations
wq, wk, wv = torch.split(mha.in_proj_weight, [embed_dim, embed_dim, embed_dim], dim=0)

q = torch.matmul(x, wq)
k = torch.matmul(x, wk)
v = torch.matmul(x, wv)

dk = embed_dim // num_heads
attention_map_manual = torch.matmul(q, k.transpose(0, 1)) / (math.sqrt(dk))
attention_map_manual = attention_map_manual.softmax(dim=1)

torch.allclose(attention_map_manual, attn_output_weights, atol=1e-4)  # -> returns False

Why does it return False? What is wrong with my calculations?
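A quick way to narrow this down is to inspect the shapes involved. The sketch below assumes a recent PyTorch release (roughly 1.11 or later, which accepts unbatched `(L, E)` input and the `average_attn_weights` argument of `forward`); `torch.manual_seed(0)` is added only for reproducibility. It suggests that `in_proj_weight` stacks the query, key and value weights into a `(3 * embed_dim, embed_dim)` matrix, and that the weights returned by default are an average over the heads, not one full-width attention map.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducibility only
embed_dim, num_heads = 8, 2
seq_len = 2
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, bias=False)

x = torch.rand(seq_len, embed_dim)  # unbatched input: (L, E)

# Default call: attention weights are averaged over the heads -> (L, S)
_, w_avg = mha(x, x, x)
# Same call, but keep one attention map per head -> (num_heads, L, S)
_, w_heads = mha(x, x, x, average_attn_weights=False)

print(mha.in_proj_weight.shape)  # torch.Size([24, 8]): rows are [W_q; W_k; W_v]
print(w_avg.shape)               # torch.Size([2, 2])
print(w_heads.shape)             # torch.Size([2, 2, 2])

# Averaging the per-head maps should reproduce the default output
print(torch.allclose(w_heads.mean(dim=0), w_avg, atol=1e-6))
```

The last line should print `True`, which hints that a single softmax over the full `embed_dim` is not what the module computes.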


OK, I figured it out by looking at the source code. For anyone who wants to understand the weights and calculations in multi-head attention, here is a simple gist:
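The gist itself is not reproduced in this text. As a minimal sketch of what the source code suggests, assuming a recent PyTorch version, unbatched input and `bias=False` as in the question, the manual calculation needs two changes: the projections follow the `nn.Linear` convention `x @ W.T` rather than `x @ W`, and attention is computed separately for each head on `head_dim = embed_dim // num_heads` slices, with the returned weights averaged over the heads. `split_heads` is a hypothetical helper introduced here only for readability.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducibility only
embed_dim, num_heads = 8, 2
seq_len = 2
head_dim = embed_dim // num_heads

mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0.0, bias=False)
x = torch.rand(seq_len, embed_dim)  # unbatched input: (L, E)

# Reference: output (L, E) and head-averaged weights (L, L)
attn_output, attn_output_weights = mha(x, x, x)

# in_proj_weight stacks [W_q; W_k; W_v], each of shape (E, E)
wq, wk, wv = mha.in_proj_weight.split(embed_dim, dim=0)

# Projections follow nn.Linear's convention: y = x @ W.T
q = x @ wq.T
k = x @ wk.T
v = x @ wv.T

def split_heads(t):
    # Hypothetical helper: (L, E) -> (num_heads, L, head_dim),
    # giving each head a contiguous slice of the embedding dimension
    return t.view(seq_len, num_heads, head_dim).transpose(0, 1)

qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)

# Scaled dot-product attention, computed independently for each head
scores = qh @ kh.transpose(-2, -1) / math.sqrt(head_dim)  # (H, L, L)
weights = scores.softmax(dim=-1)                           # (H, L, L)
heads = weights @ vh                                       # (H, L, head_dim)

# Concatenate the heads back to (L, E) and apply the output projection
concat = heads.transpose(0, 1).reshape(seq_len, embed_dim)
output_manual = concat @ mha.out_proj.weight.T  # bias=False, so out_proj has no bias

# By default the module returns the attention weights averaged over heads
weights_manual = weights.mean(dim=0)

print(torch.allclose(weights_manual, attn_output_weights, atol=1e-5))
print(torch.allclose(output_manual, attn_output, atol=1e-5))
```

Both comparisons are expected to print `True`. Passing `average_attn_weights=False` to `mha(...)` should instead return per-head weights of shape `(num_heads, seq_len, seq_len)`, which can be compared against `weights` directly.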