Which Multi-Head Attention Implementation Is Correct?

Both versions below run and print the expected output shape of [32, 25, 768]; the question is whether the attention computation itself is correct.

import math

import torch
from torch import nn, Tensor


# Candidate 1: queries, keys, and values are used at full embedding width.
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert self.head_dim*n_heads==embed_size, "embed_size must be divisible by n_heads"
        
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)
    
    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        
        query = self.query(seq_embed)
        key = self.key(seq_embed)
        value = self.value(seq_embed)

        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(embed_size)
        if mask is not None:
            scores = scores.masked_fill(mask==0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        return self.linear(attention)


attention = MultiHeadAttention(embed_size=768, n_heads=12)
seq = torch.zeros([32, 25, 768])
print(attention(seq).shape) #-> output shape [32, 25, 768]


# Candidate 2: queries, keys, and values are reshaped to (batch, seq_len, n_heads, head_dim).
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        assert self.head_dim*n_heads==embed_size, "embed_size must be divisible by n_heads"
        
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)
    
    def forward(self, seq_embed:Tensor, mask:Tensor=None)->Tensor:
        batch_size, seq_len, embed_size = seq_embed.shape
        
        query = self.query(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        key = self.key(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)
        value = self.value(seq_embed).reshape(batch_size, seq_len, self.n_heads, self.head_dim)

        scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(embed_size)
        if mask is not None:
            scores = scores.masked_fill(mask==0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)
        attention = attention.reshape(batch_size, seq_len, self.n_heads*self.head_dim)
        return self.linear(attention)


attention = MultiHeadAttention(embed_size=768, n_heads=12)
seq = torch.zeros([32, 25, 768])
print(attention(seq).shape) #-> output shape [32, 25, 768]
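

For comparison, here is a reference sketch of the textbook multi-head attention forward pass, following the formulation in "Attention Is All You Need": the projections are split into heads, the head axis is moved in front of the sequence axis so that each head attends over the whole sequence, scores are scaled by sqrt(head_dim), and the heads are merged back before the output projection. The class name ReferenceMultiHeadAttention is introduced here purely for illustration.

import math

import torch
from torch import nn, Tensor


class ReferenceMultiHeadAttention(nn.Module):
    def __init__(self, embed_size:int, n_heads:int) -> None:
        super().__init__()
        assert embed_size % n_heads == 0, "embed_size must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = embed_size//n_heads
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)
        self.linear = nn.Linear(embed_size, embed_size)

    def forward(self, seq_embed:Tensor, mask:Tensor=None) -> Tensor:
        batch_size, seq_len, _ = seq_embed.shape

        def split_heads(x:Tensor) -> Tensor:
            # (batch, seq_len, embed_size) -> (batch, n_heads, seq_len, head_dim)
            return x.reshape(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        query = split_heads(self.query(seq_embed))
        key = split_heads(self.key(seq_embed))
        value = split_heads(self.value(seq_embed))

        # scaled dot-product attention within each head
        scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask==0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        attention = torch.matmul(weights, value)

        # merge the heads back to (batch, seq_len, embed_size), then project
        attention = attention.transpose(1, 2).reshape(batch_size, seq_len, self.n_heads*self.head_dim)
        return self.linear(attention)


attention = ReferenceMultiHeadAttention(embed_size=768, n_heads=12)
seq = torch.zeros([32, 25, 768])
print(attention(seq).shape) #-> output shape [32, 25, 768]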
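
Either candidate can also be checked empirically against torch.nn.MultiheadAttention by copying its projection weights into the built-in module and comparing outputs. A minimal sketch, assuming a reasonably recent PyTorch (batch_first needs 1.9 or newer) and that one of the MultiHeadAttention classes above is in scope; the candidate's output bias is zeroed because the built-in module is created without biases:

import torch
from torch import nn

candidate = MultiHeadAttention(embed_size=768, n_heads=12)
reference = nn.MultiheadAttention(embed_dim=768, num_heads=12, bias=False, batch_first=True)

with torch.no_grad():
    # nn.MultiheadAttention keeps the q/k/v projections stacked in in_proj_weight
    reference.in_proj_weight.copy_(torch.cat(
        [candidate.query.weight, candidate.key.weight, candidate.value.weight], dim=0))
    reference.out_proj.weight.copy_(candidate.linear.weight)
    candidate.linear.bias.zero_()

x = torch.randn(4, 10, 768)
expected, _ = reference(x, x, x, need_weights=False)
print(torch.allclose(candidate(x), expected, atol=1e-5))  # True only when the candidate agrees with the built-in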