I was trying to write my own attention function for a project I'm working on. However, when I compared the output and weights from my code with those from torch.nn.MultiheadAttention, I noticed that the softmax(QK^T / d_k^0.5) step appears to be calculated incorrectly. Here is my code:
import torch
import torch.nn.functional as F
from torch.nn import MultiheadAttention
def attention(Q, K, V):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k**0.5)
    attn_output_weights = F.softmax(scores, dim=-1)
    attn_output = torch.matmul(attn_output_weights, V)
    return attn_output, attn_output_weights
embed_dim = 8
num_heads = 1
batch_size = 2
seq_len = 5
Q = torch.randn(batch_size, seq_len, embed_dim)
K = torch.randn(batch_size, seq_len, embed_dim)
V = torch.randn(batch_size, seq_len, embed_dim)
multihead_attn = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
attn_output_pytorch, attn_output_weights_pytorch = multihead_attn(Q, K, V)
attn_output_custom, attn_output_weights_custom = attention(Q, K, V)
assert torch.allclose(attn_output_custom, attn_output_pytorch, rtol=1e-6, atol=1e-8), "Attention output does not match."
assert torch.allclose(attn_output_weights_custom, attn_output_weights_pytorch, rtol=1e-6, atol=1e-8), "Attention weights do not match."
I tried changing the hyperparameters, printing each matrix, dropping the d_k^0.5 normalization, comparing against torch.nn.functional.scaled_dot_product_attention, and checking the shape of each tensor, but I still can't get the results to match. I am primarily concerned with matching attn_output_weights_custom and attn_output_weights_pytorch.
Can someone spot what I might be doing wrong?