Which Multihead Attention Implementation is Correct?

import math
import torch
from torch import nn, Tensor

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert self.head_dim * n_heads == embed_size, "embed_size must be divisible by n_heads"
        
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)
    
    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        
        query = self.query(seq_embed)
        key = self.key(seq_embed)
        value = self.value(seq_embed)

        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(embed_size)
        if mask is not None:
            scores.masked_fill(mask==0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        return self.linear(attention)


attention = MultiHeadAttention(embed_size=768, n_heads=12)
seq = torch.zeros([32, 25, 768])
print(attention(seq).shape) #-> output shape [32, 25, 768]
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert self.head_dim * n_heads == embed_size, "embed_size must be divisible by n_heads"
        
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)
    
    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        
        query = self.query(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        key = self.key(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        value = self.value(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)

        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(embed_size)
        if mask is not None:
            scores.masked_fill(mask==0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        attention = attention.reshape(batch_size, seq_len, self.n_heads*self.head_dim)
        return self.linear(attention)


attention = MultiHeadAttention(embed_size=768, n_heads=12)
seq = torch.zeros([32, 25, 768])
print(attention(seq).shape) #-> output shape [32, 25, 768]

Hello,

The idea of multi-head attention is that we run several self-attention computations (attention heads) in parallel, and the outputs of those heads are then concatenated.
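As a minimal shape-level sketch of that idea (toy sizes of my own choosing, no learned projections, just to show the split into heads and the concatenation back to embed_size):

import math
import torch

# Toy sizes just for illustration
batch_size, seq_len, embed_size, n_heads = 2, 5, 8, 2
head_dim = embed_size // n_heads

x = torch.randn(batch_size, seq_len, embed_size)

# Split the embedding into n_heads chunks of head_dim each
heads_in = x.chunk(n_heads, dim=-1)   # n_heads tensors of [batch, seq, head_dim]

heads_out = []
for h in heads_in:
    # Plain scaled dot-product attention inside each head
    scores = h @ h.transpose(-2, -1) / math.sqrt(head_dim)   # [batch, seq, seq]
    weights = torch.softmax(scores, dim=-1)
    heads_out.append(weights @ h)     # [batch, seq, head_dim]

# Concatenate the per-head outputs back to the original embedding size
out = torch.cat(heads_out, dim=-1)    # [batch, seq, embed_size]
print(out.shape)                      # torch.Size([2, 5, 8])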

This one is just self-attention with a single head.

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert self.head_dim * n_heads == embed_size, "embed_size must be divisible by n_heads"
        
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)
    
    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        
        query = self.query(seq_embed)
        key = self.key(seq_embed)
        value = self.value(seq_embed)

        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(embed_size)
        if mask is not None:
            scores.masked_fill(mask==0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        return self.linear(attention)

This one is closer to the correct version, but you forgot to transpose the 1st dim (seq_len) with the 2nd dim (n_heads), so that each head gets its own seq_len x seq_len attention matrix. Then, after multiplying the attention weights by the values, the heads must be concatenated back together so that the result matches the dimension of the input (there is a quick shape check right after the code below).

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert self.head_dim * n_heads == embed_size, "embed_size must be divisible by n_heads"
        
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)
    
    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        
        query = self.query(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        key = self.key(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        value = self.value(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)

        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(embed_size)
        if mask is not None:
            scores.masked_fill(mask==0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        attention = attention.reshape(batch_size, seq_len, self.n_heads*self.head_dim)
        return self.linear(attention)
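To see concretely why that transpose matters, here is a quick shape check (toy sizes of my own choosing): without moving n_heads in front of seq_len, the last two dimensions entering the matmul are (n_heads, head_dim), so you end up with a [batch, seq, heads, heads] score tensor instead of one [seq, seq] attention matrix per head.

import torch

batch_size, seq_len, n_heads, head_dim = 2, 5, 2, 4

q = torch.randn(batch_size, seq_len, n_heads, head_dim)
k = torch.randn(batch_size, seq_len, n_heads, head_dim)

# Without the transpose: matmul pairs up heads instead of positions
wrong = torch.matmul(q, k.transpose(-2, -1))
print(wrong.shape)   # torch.Size([2, 5, 2, 2]) -> [batch, seq, heads, heads]

# After moving n_heads in front of seq_len: [batch, n_heads, seq_len, head_dim]
q_t, k_t = q.transpose(1, 2), k.transpose(1, 2)
right = torch.matmul(q_t, k_t.transpose(-2, -1))
print(right.shape)   # torch.Size([2, 2, 5, 5]) -> [batch, heads, seq, seq]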

Here is the correct version:

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert self.head_dim * n_heads == embed_size, "embed_size must be divisible by n_heads"
        
        self.query  = nn.Linear(embed_size, embed_size, bias=False)
        self.key    = nn.Linear(embed_size, embed_size, bias=False)
        self.value  = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)
    
    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        
        # Project, split into heads, and move n_heads in front of seq_len:
        # [batch, seq, embed] -> [batch, n_heads, seq, head_dim]
        query   =   self.query(seq_embed).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1,2)
        key     =   self.key(seq_embed).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1,2)
        value   =   self.value(seq_embed).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1,2)

        # Scaled dot-product attention per head; scores are [batch, n_heads, seq, seq]
        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(query.size(-1))
        if mask is not None:
            scores = scores.masked_fill(mask==0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        
        attention = torch.matmul(weights, value)
        # Move the heads back next to head_dim and concatenate them: [batch, seq, embed]
        attention = attention.transpose(1,2).contiguous().view(batch_size, seq_len, embed_size)
        return self.linear(attention)
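A quick smoke test for this version, using the same toy shapes as your example plus an optional causal mask (the [seq_len, seq_len] mask here is just an assumption on my side; anything that broadcasts against the [batch, heads, seq, seq] scores works):

attention = MultiHeadAttention(embed_size=768, n_heads=12)
seq = torch.zeros([32, 25, 768])
print(attention(seq).shape) #-> output shape [32, 25, 768]

# A lower-triangular (causal) mask that broadcasts over batch and heads
causal_mask = torch.tril(torch.ones(25, 25))
print(attention(seq, mask=causal_mask).shape) #-> output shape [32, 25, 768]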

You can check these implementations:

  1. Vision Transformer (ViT) - Attention
  2. minGPT - CausalSelfAttention
  3. BERT - BertSelfAttention

… since the correct version I showed above is essentially the same implementation used in those sources.

Let me know if this helps.

Edit:
I updated the scaling factor to sqrt(query.size(-1)), i.e. sqrt(head_dim), instead of sqrt(embed_size).

Hi, thanks for the reply. Could you tell me whether this version is correct or not?

class SelfAttention(nn.Module):
    def __init__(self, hidden_size:int, head_dim:int):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(hidden_size, head_dim, bias=False)
        self.key = nn.Linear(hidden_size, head_dim, bias=False)
        self.value = nn.Linear(hidden_size, head_dim, bias=False)

    def forward(self, Q:Tensor, K:Tensor, V:Tensor, mask:Tensor=None)->Tensor:
        Q, K, V = self.query(Q), self.key(K), self.value(V)
        scores = torch.bmm(Q, K.transpose(1, 2))/math.sqrt(K.shape[-1])
        if mask is not None:
            scores = scores.masked_fill(mask==0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        attention = torch.bmm(weights, V)
        return attention

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size:int, n_heads:int):
        super(MultiHeadAttention, self).__init__()
        self.head_dim = hidden_size//n_heads
        assert self.head_dim * n_heads == hidden_size, "hidden_size must be divisible by n_heads"
        self.heads = nn.ModuleList([SelfAttention(hidden_size, self.head_dim) for _ in range (n_heads)])
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, Q:Tensor, K:Tensor, V:Tensor, mask:Tensor=None)->Tensor:
        concat = torch.cat([head(Q, K, V, mask) for head in self.heads], dim=-1)
        output = self.linear(concat)
        return output

Yes, this code is a more intuitive version of the one I sent above. However, it computes the heads sequentially because of the for _ in range(n_heads) loop over the nn.ModuleList, so you may prefer the version I sent, which computes all heads in one batched matrix multiplication and is faster.

You can debug or compare them by just using smaller input / hyperparameter sizes like this:

num_heads = 2
batch_no = 2
window_size = 5
hidden_size = 4

Then access the weights of the Q, K and V projections through their .weight attribute to redo each matrix multiplication by hand.
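For example, here is a small sketch of that manual check with the toy sizes above (assuming your SelfAttention class from the previous post is in scope; the variable names are just illustrative):

import math
import torch

num_heads, batch_no, window_size, hidden_size = 2, 2, 5, 4
head_dim = hidden_size // num_heads

x = torch.randn(batch_no, window_size, hidden_size)
head = SelfAttention(hidden_size, head_dim)

# nn.Linear stores its weight as [out_features, in_features], so y = x @ W.T
q_manual = x @ head.query.weight.T                 # [batch_no, window_size, head_dim]
print(torch.allclose(q_manual, head.query(x)))     # True

# Same idea for the keys and for the attention scores of this head
k_manual = x @ head.key.weight.T
scores_manual = torch.bmm(q_manual, k_manual.transpose(1, 2)) / math.sqrt(head_dim)
print(scores_manual.shape)                         # torch.Size([2, 5, 5])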

Regards