Hi everyone,
I've implemented two slightly different versions of multi-head self-attention. As far as I can tell they should be equivalent, but they give different outputs even when all the weights and inputs are exactly the same. Where is the problem, and which one is correct?
v1 (modified from: https://github.com/pbloem/former/blob/master/former/modules.py):
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

class MHSA(nn.Module):
    def __init__(self,
                 emb_dim,
                 kqv_dim,
                 num_heads=1):
        super(MHSA, self).__init__()
        self.emb_dim = emb_dim      # dimension of input & output
        self.kqv_dim = kqv_dim      # dimension per head
        self.num_heads = num_heads  # number of heads
        self.w_k = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_q = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_v = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_out = nn.Linear(kqv_dim * num_heads, emb_dim)

    def forward(self, x):
        b, t, _ = x.shape
        e = self.kqv_dim
        h = self.num_heads

        keys = self.w_k(x).view(b, t, h, e)
        values = self.w_v(x).view(b, t, h, e)
        queries = self.w_v(x).view(b, t, h, e)

        keys = keys.transpose(1, 2).contiguous().view(b * h, t, e)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, e)
        values = values.transpose(1, 2).contiguous().view(b * h, t, e)

        dot = torch.bmm(queries, keys.transpose(1, 2))
        dot = dot / np.sqrt(e)
        dot = F.softmax(dot, dim=2)

        out = torch.bmm(dot, values).view(b, h, t, e)
        out = out.transpose(1, 2).contiguous().view(b, t, h * e)
        out = self.w_out(out)
        return out
v2:
class MHSA(nn.Module):
    def __init__(self,
                 emb_dim,
                 kqv_dim,
                 num_heads=1):
        super(MHSA, self).__init__()
        self.emb_dim = emb_dim
        self.kqv_dim = kqv_dim
        self.num_heads = num_heads
        self.w_k = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_q = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_v = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_out = nn.Linear(kqv_dim * num_heads, emb_dim)

    def forward(self, x):
        b, t, _ = x.shape
        e = self.kqv_dim
        h = self.num_heads

        keys = self.w_k(x).view(b, t, h, e)
        values = self.w_v(x).view(b, t, h, e)
        queries = self.w_q(x).view(b, t, h, e)

        keys = keys.transpose(2, 1)
        queries = queries.transpose(2, 1)
        values = values.transpose(2, 1)

        dot = queries @ keys.transpose(3, 2)
        dot = dot / np.sqrt(e)
        dot = F.softmax(dot, dim=3)

        out = dot @ values
        out = out.transpose(1, 2).contiguous().view(b, t, h * e)
        out = self.w_out(out)
        return out
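
In case it helps, this is roughly how I'm comparing the two (a minimal sketch; the class names clash, so assume the versions above are imported as MHSAv1 and MHSAv2, and the weights of one are copied into the other):

import torch

torch.manual_seed(0)

emb_dim, kqv_dim, num_heads = 16, 8, 4
m1 = MHSAv1(emb_dim, kqv_dim, num_heads)  # v1 from above, renamed to avoid the clash
m2 = MHSAv2(emb_dim, kqv_dim, num_heads)  # v2 from above, renamed to avoid the clash
m2.load_state_dict(m1.state_dict())       # same parameter names/shapes in both, so this makes the weights identical

x = torch.randn(2, 5, emb_dim)            # (batch, sequence length, embedding dim)
print(torch.allclose(m1(x), m2(x)))       # prints False for me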