Hello,
The idea of Multi-head Attention is that we have multiple self-attention operations (attention heads) computed in parallel, and the outputs of those heads are then concatenated back together.
This one is just self-attention with a single head:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert(self.head_dim*n_heads==embed_size), "embed_size must be divisible by n_heads"
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)

    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        query = self.query(seq_embed)
        key = self.key(seq_embed)
        value = self.value(seq_embed)
        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(embed_size)
        if mask is not None:
            scores = scores.masked_fill(mask==0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        return self.linear(attention)
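As a quick shape check for this single-head version, there is only one seq_len x seq_len attention matrix per batch element (a sketch with made-up sizes, using the input itself as a stand-in for the Linear projections):

import torch

batch_size, seq_len, embed_size = 2, 5, 32
x = torch.randn(batch_size, seq_len, embed_size)
q = k = v = x                                                      # stand-ins for the Linear projections
scores = torch.matmul(q, k.transpose(-2, -1)) / embed_size**0.5    # (2, 5, 5): one attention matrix per batch element
weights = torch.softmax(scores, dim=-1)
out = torch.matmul(weights, v)                                     # (2, 5, 32): back to the input shape
print(scores.shape, out.shape)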
This one is closer to the correct version, but you forgot to transpose the 1st dim (seq_len) with the 2nd dim (n_heads), so that each head gets its own seq_len x seq_len attention matrix. Then, after attention @ values, we must concatenate the heads back together so that the result matches the dimensions of the input.
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert(self.head_dim*n_heads==embed_size), "embed_size must be divisible by n_heads"
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)

    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        query = self.query(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        key = self.key(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        value = self.value(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(embed_size)
        if mask is not None:
            scores = scores.masked_fill(mask==0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        attention = attention.reshape(batch_size, seq_len, self.n_heads*self.head_dim)
        return self.linear(attention)
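To see why the transpose matters, here is a quick shape check (just a sketch with made-up sizes; torch.matmul treats all leading dimensions as batch dimensions):

import torch

batch_size, seq_len, n_heads, head_dim = 2, 5, 4, 8

# Without the transpose, the heads sit in dim 2, so matmul contracts over the wrong axes
q = torch.randn(batch_size, seq_len, n_heads, head_dim)
k = torch.randn(batch_size, seq_len, n_heads, head_dim)
print(torch.matmul(q, k.transpose(-2, -1)).shape)    # (2, 5, 4, 4) = (batch, seq_len, n_heads, n_heads)

# With transpose(1, 2), each head gets its own seq_len x seq_len attention matrix
q, k = q.transpose(1, 2), k.transpose(1, 2)          # (batch, n_heads, seq_len, head_dim)
print(torch.matmul(q, k.transpose(-2, -1)).shape)    # (2, 4, 5, 5) = (batch, n_heads, seq_len, seq_len)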
Here I wrote the correct version:
import math
import torch
from torch import nn, Tensor

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert(self.head_dim*n_heads==embed_size), "embed_size must be divisible by n_heads"
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)

    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        # split into heads: (batch, seq_len, embed_size) -> (batch, n_heads, seq_len, head_dim)
        query = self.query(seq_embed).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1,2)
        key = self.key(seq_embed).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1,2)
        value = self.value(seq_embed).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1,2)
        # (batch, n_heads, seq_len, seq_len), scaled by sqrt(head_dim)
        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(query.size(-1))
        if mask is not None:
            scores = scores.masked_fill(mask==0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        # concatenate the heads back: (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, embed_size)
        attention = attention.transpose(1,2).contiguous().view((batch_size, seq_len, embed_size))
        return self.linear(attention)
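A quick sanity check of the module above (a minimal sketch; the sizes and the causal mask are just made up for the example):

embed_size, n_heads, batch_size, seq_len = 512, 8, 2, 10

mha = MultiHeadAttention(embed_size, n_heads)
x = torch.randn(batch_size, seq_len, embed_size)

# causal mask: position i may only attend to positions <= i;
# it broadcasts against scores of shape (batch, n_heads, seq_len, seq_len)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

out = mha(x, mask=causal_mask)
print(out.shape)   # torch.Size([2, 10, 512])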
You can check these implementations:
- Vision Transformer (ViT) - Attention
- minGPT - CausalSelfAttention
- BERT - BertSelfAttention
… since the solution I showed above is essentially the same implementation as in those links.
Let me know if this helps
Edit:
I updated the scaling factor: the scores are now divided by sqrt(head_dim) (i.e. math.sqrt(query.size(-1))) instead of sqrt(embed_size), since the dot products inside each head are over head_dim dimensions, not embed_size.