My implementation of self attention

Hi everyone
I’ve implemented 2 slightly different versions of multihead self-attention. In my head they should be equivalent to each other, but they’re giving different outputs even if all the weights and inputs are the exact same. where is the problem? which one is correct?
v1 (modified from: https://github.com/pbloem/former/blob/master/former/modules.py):

class MHSA(nn.Module):
  def __init__(self,
         emb_dim,
         kqv_dim,
         num_heads=1):
    super(MHSA, self).__init__()
    self.emb_dim = emb_dim # dimention of input&output
    self.kqv_dim = kqv_dim  # dimention per head
    self.num_heads = num_heads # number of heads

    self.w_k = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
    self.w_q = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
    self.w_v = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
    self.w_out = nn.Linear(kqv_dim * num_heads, emb_dim)

  def forward(self, x):
    b, t, _ = x.shape
    e = self.kqv_dim
    h = self.num_heads

    keys = self.w_k(x).view(b, t, h, e)
    values = self.w_v(x).view(b, t, h, e)
    queries = self.w_v(x).view(b, t, h, e)

    keys = keys.transpose(1, 2).contiguous().view(b * h, t, e)
    queries = queries.transpose(1, 2).contiguous().view(b * h, t, e)
    values = values.transpose(1, 2).contiguous().view(b * h, t, e)
    
    dot = torch.bmm(queries, keys.transpose(1, 2))
    dot = dot / np.sqrt(e)
    dot = F.softmax(dot, dim=2)
    out = torch.bmm(dot, values).view(b, h, t, e)
    out = out.transpose(1,2).contiguous().view(b, t, h * e)
    out = self.w_out(out)

v2:

class MHSA(nn.Module):
  def __init__(self,
         emb_dim,
         kqv_dim,
         num_heads=1):
    super(MHSA, self).__init__()
    self.emb_dim = emb_dim
    self.kqv_dim = kqv_dim
    self.num_heads = num_heads

    self.w_k = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
    self.w_q = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
    self.w_v = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
    self.w_out = nn.Linear(kqv_dim * num_heads, emb_dim)

  def forward(self, x):

    b, t, _ = x.shape
    e = self.kqv_dim
    h = self.num_heads
    keys = self.w_k(x).view(b, t, h, e)
    values = self.w_v(x).view(b, t, h, e)
    queries = self.w_q(x).view(b, t, h, e)

    keys = keys.transpose(2, 1)
    queries = queries.transpose(2, 1)
    values = values.transpose(2, 1)

    dot = queries @ keys.transpose(3, 2)
    dot = dot / np.sqrt(e)
    dot = F.softmax(dot, dim=3)

    out = dot @ values
    out = out.transpose(1,2).contiguous().view(b, t, h * e)
    out = self.w_out(out)
    return out
1 Like

I can’t believe I made this silly mistake… in verson1 queries are outputted from w_v, instead of w_q.