class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention.

    NOTE(review): despite the name, this version never splits the embedding
    into heads — Q, K, V keep the full ``embed_size`` width, so it computes a
    single attention map per sequence. Scaling by ``sqrt(embed_size)`` is
    therefore the correct d_k here. The head-splitting variant is defined
    separately below in this file.
    """

    def __init__(self, embed_size: int, n_heads: int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size // n_heads
        # Integer division must be exact, otherwise heads cannot tile the embedding.
        assert (self.head_dim * n_heads == embed_size), "embed_size must be divisible by n_heads"
        # Projections for queries, keys and values (no bias, as in the original Transformer).
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        # Final output projection.
        self.linear = nn.Linear(embed_size, embed_size)

    def forward(self, seq_embed: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``seq_embed``.

        Args:
            seq_embed: input of shape (batch, seq_len, embed_size).
            mask: optional tensor broadcastable to (batch, seq_len, seq_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, embed_size).
        """
        batch_size, seq_len, embed_size = seq_embed.shape
        query = self.query(seq_embed)
        key = self.key(seq_embed)
        value = self.value(seq_embed)
        # (B, S, E) @ (B, E, S) -> (B, S, S), scaled by sqrt(d_k).
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(embed_size)
        if mask is not None:
            # BUG FIX: masked_fill is out-of-place and its result was being
            # discarded, so the mask previously had no effect. Reassign it.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        return self.linear(attention)
# Smoke test: 32 zero sequences of length 25 and width 768 should round-trip
# through attention with their shape unchanged.
attention = MultiHeadAttention(embed_size=768, n_heads=12)
seq = torch.zeros(32, 25, 768)
print(attention(seq).shape)  # expected: torch.Size([32, 25, 768])
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The embedding is projected to Q/K/V, split into ``n_heads`` heads of
    width ``head_dim``, attended independently per head, then concatenated
    and passed through a final linear projection.
    """

    def __init__(self, embed_size: int, n_heads: int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size // n_heads
        # Integer division must be exact, otherwise heads cannot tile the embedding.
        assert (self.head_dim * n_heads == embed_size), "embed_size must be divisible by n_heads"
        # Projections for queries, keys and values (no bias, as in the original Transformer).
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        # Final output projection after head concatenation.
        self.linear = nn.Linear(embed_size, embed_size)

    def forward(self, seq_embed: Tensor, mask: Tensor = None) -> Tensor:
        """Apply multi-head self-attention to ``seq_embed``.

        Args:
            seq_embed: input of shape (batch, seq_len, embed_size).
            mask: optional tensor broadcastable to
                (batch, n_heads, seq_len, seq_len); positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, embed_size).
        """
        batch_size, seq_len, embed_size = seq_embed.shape
        # BUG FIX: the heads must be moved in front of the sequence axis,
        # (B, S, H, D) -> (B, H, S, D). Without the transpose, the matmul
        # below attended over the *head* axis (scores of shape (B, S, H, H))
        # instead of over sequence positions.
        query = self.query(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        key = self.key(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        value = self.value(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        # (B, H, S, D) @ (B, H, D, S) -> (B, H, S, S).
        # BUG FIX: scale by sqrt(head_dim) — each head works in head_dim
        # dimensions, not embed_size.
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # BUG FIX: masked_fill is out-of-place and its result was being
            # discarded, so the mask previously had no effect. Reassign it.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        # (B, H, S, S) @ (B, H, S, D) -> (B, H, S, D).
        attention = torch.matmul(weights, value)
        # Move heads back next to head_dim and concatenate: (B, S, H*D).
        attention = attention.transpose(1, 2).reshape(batch_size, seq_len, self.n_heads * self.head_dim)
        return self.linear(attention)
# Smoke test for the head-splitting implementation: 768 = 12 heads x 64 dims,
# and the output shape must match the input shape.
attention = MultiHeadAttention(embed_size=768, n_heads=12)
seq = torch.zeros((32, 25, 768))
print(attention(seq).shape)  # expected: torch.Size([32, 25, 768])