In the current implementation of nn.MultiheadAttention, model_dim has to be equal to num_heads * kqv_dim. Is it possible to lift this restriction without reimplementing MHA? For example, I want the input and output of multi-head attention to be 512-dimensional, while having 24 heads of 64 dimensions each. If that’s not possible, could someone recommend a good implementation of multi-head attention?
I used this one https://github.com/pbloem/former/blob/master/former/modules.py to train a small BERT-style language model on wikitext103, but for some reason it isn’t converging.
class MHSA(nn.Module):
    """Multi-head self-attention whose per-head width is independent of the model width.

    The input is projected into ``num_heads`` heads of ``kqv_dim`` features
    each, so ``num_heads * kqv_dim`` does not have to equal ``emb_dim``. The
    heads run scaled dot-product attention, and the concatenated heads are
    projected back to ``emb_dim``. Optionally a residual connection is added
    and the result is layer-normalized (post-norm).

    Args:
        emb_dim: feature size of the input and of the output.
        kqv_dim: feature size of each head's keys, queries and values.
        num_heads: number of attention heads.
        residual: if True, add the (un-dropped) block input to the attention
            output.
        norm: if True, apply ``nn.LayerNorm(emb_dim)`` to the result;
            otherwise the result is returned unchanged.
        dropout: dropout probability applied to the attention sub-layer output
            before the residual addition.
    """

    def __init__(self,
                 emb_dim,
                 kqv_dim,
                 num_heads,
                 residual=True,
                 norm=True,
                 dropout=0):
        super().__init__()
        self.emb_dim = emb_dim
        self.kqv_dim = kqv_dim
        self.num_heads = num_heads
        self.residual = residual
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(emb_dim) if norm else nn.Identity()

        inner_dim = kqv_dim * num_heads
        self.w_k = nn.Linear(emb_dim, inner_dim)
        self.w_q = nn.Linear(emb_dim, inner_dim)
        self.w_v = nn.Linear(emb_dim, inner_dim)
        self.w_out = nn.Linear(inner_dim, emb_dim)

    def forward(self, x, key_padding_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``(batch, seq_len, emb_dim)``; the shape is
                unpacked and reshaped accordingly below.
            key_padding_mask: optional bool tensor of shape
                ``(batch, seq_len)``. ``True`` marks key positions that must
                not be attended to (e.g. padding). This is the same convention
                as ``nn.MultiheadAttention``. If every key of a sequence is
                masked, its attention row becomes NaN.

        Returns:
            Tensor of shape ``(batch, seq_len, emb_dim)``.
        """
        b, t, _ = x.shape
        h, e = self.num_heads, self.kqv_dim

        def split_heads(proj):
            # (b, t, h * e) -> (b, h, t, e)
            return proj.view(b, t, h, e).transpose(1, 2)

        # Scale q and k by e**-0.25 each, so their product is scaled by
        # 1/sqrt(e). Splitting the factor keeps intermediate values smaller.
        scale = e ** 0.25
        queries = split_heads(self.w_q(x)) / scale
        keys = split_heads(self.w_k(x)) / scale
        values = split_heads(self.w_v(x))

        scores = queries @ keys.transpose(-2, -1)  # (b, h, t, t)
        if key_padding_mask is not None:
            # Broadcast (b, t) over heads and query positions.
            scores = scores.masked_fill(
                key_padding_mask[:, None, None, :], float("-inf"))
        weights = F.softmax(scores, dim=-1)

        out = (weights @ values).transpose(1, 2).contiguous().view(b, t, h * e)
        # Dropout belongs on the sub-layer output. The residual must carry
        # the clean input; the original code added a dropped-out copy of x.
        out = self.dropout(self.w_out(out))
        if self.residual:
            out = out + x
        return self.layer_norm(out)
Could someone point out where the problem is?