I am trying to understand PyTorch's nn.MultiheadAttention implementation and attempted to reproduce its attention weights manually with the following code:
import math
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0, bias=False, add_bias_kv=False, add_zero_attn=False)
seq_len = 2
x = torch.rand(seq_len, embed_dim)  # unbatched input of shape (seq_len, embed_dim)
# Self-attention: reference calculation
attn_output, attn_output_weights = mha(x, x, x)
# My manual calculations
# in_proj_weight has shape (3 * embed_dim, embed_dim) and packs the query, key and value projection weights along dim 0
wq, wk, wv = torch.split(mha.in_proj_weight, [embed_dim, embed_dim, embed_dim], dim=0)
q = torch.matmul(x, wq)
k = torch.matmul(x, wk)
v = torch.matmul(x, wv)
dk = embed_dim // num_heads  # per-head dimension, used for the sqrt(d_k) scaling
# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k))
attention_map_manual = torch.matmul(q, k.transpose(0, 1)) / math.sqrt(dk)
attention_map_manual = attention_map_manual.softmax(dim=1)
torch.allclose(attention_map_manual, attn_output_weights, atol=1e-4)  # -> returns False
Why does this return False? What is wrong with my calculations?
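For context, here is a sketch of what I assume the module does internally: the projections are applied as F.linear (i.e. x @ W.T rather than x @ W), the projected tensors are split into num_heads chunks of size head_dim, scaled dot-product attention is computed per head, and the returned attn_output_weights are averaged over the heads. I have not verified this against the PyTorch source, so it is only my guess; the snippet reuses x, wq, wk, seq_len, embed_dim, num_heads and attn_output_weights from the code above.

import math
import torch
import torch.nn.functional as F

head_dim = embed_dim // num_heads  # 8 // 2 = 4

# Projections applied as F.linear, i.e. x @ W.T (my assumption)
q2 = F.linear(x, wq)  # (seq_len, embed_dim)
k2 = F.linear(x, wk)

# Split the embedding dimension into heads: (num_heads, seq_len, head_dim)
q2 = q2.reshape(seq_len, num_heads, head_dim).transpose(0, 1)
k2 = k2.reshape(seq_len, num_heads, head_dim).transpose(0, 1)

# Per-head scaled dot-product attention
scores = torch.matmul(q2, k2.transpose(-2, -1)) / math.sqrt(head_dim)  # (num_heads, seq_len, seq_len)
weights = scores.softmax(dim=-1)

# I assume attn_output_weights is averaged over heads by default (average_attn_weights=True)
weights_avg = weights.mean(dim=0)  # (seq_len, seq_len)

print(torch.allclose(weights_avg, attn_output_weights, atol=1e-4))

Is this per-head computation the right thing to compare against, or am I misunderstanding how in_proj_weight is applied?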