Wide Self-Attention

In the current implementation of nn.MultiheadAttention, embed_dim has to be equal to num_heads * head_dim (the per-head dimension is just embed_dim // num_heads). Is it possible to break this limitation without reimplementing MHA? For example, I want the input and output of multi-head attention to be 512 dimensions, while having 24 heads with 64 dimensions per head. And if that's not possible, could someone recommend a good implementation of multi-head attention?
I used this one, https://github.com/pbloem/former/blob/master/former/modules.py, to train a small BERT-style language model on WikiText-103, but it's not converging for some reason; the module is pasted further below.
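
For concreteness, the only workaround I can think of is to run a stock nn.MultiheadAttention at the wider dimension (24 * 64 = 1536) and add linear projections in and out, roughly like the sketch below (WideMHA and the other names are just illustrative). But that adds two extra projections rather than really widening the heads, so I'm not sure it's a good idea:

import torch
import torch.nn as nn


class WideMHA(nn.Module):
  # Illustrative sketch: run nn.MultiheadAttention at num_heads * head_dim
  # internally and project to/from the 512-dim model width.
  def __init__(self, model_dim=512, num_heads=24, head_dim=64, dropout=0.0):
    super().__init__()
    inner_dim = num_heads * head_dim  # 24 * 64 = 1536
    self.proj_in = nn.Linear(model_dim, inner_dim)
    self.attn = nn.MultiheadAttention(inner_dim, num_heads, dropout=dropout)
    self.proj_out = nn.Linear(inner_dim, model_dim)

  def forward(self, x, attn_mask=None, key_padding_mask=None):
    # x: (batch, seq_len, model_dim); nn.MultiheadAttention expects (seq_len, batch, dim)
    h = self.proj_in(x).transpose(0, 1)
    h, _ = self.attn(h, h, h, attn_mask=attn_mask,
                     key_padding_mask=key_padding_mask)
    return self.proj_out(h.transpose(0, 1))  # back to (batch, seq_len, model_dim)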

This is the module I'm training with:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MHSA(nn.Module):
  def __init__(self, emb_dim, kqv_dim, num_heads,
               residual=True, norm=True, dropout=0.0):
    super().__init__()
    self.emb_dim = emb_dim
    self.kqv_dim = kqv_dim
    self.num_heads = num_heads
    self.residual = residual
    self.dropout = nn.Dropout(dropout)
    if norm:
      self.layer_norm = nn.LayerNorm(self.emb_dim)
    else:
      self.layer_norm = nn.Identity()

    # Q/K/V projections: emb_dim -> num_heads * kqv_dim (heads are split in forward)
    self.w_k = nn.Linear(emb_dim, kqv_dim * num_heads)
    self.w_q = nn.Linear(emb_dim, kqv_dim * num_heads)
    self.w_v = nn.Linear(emb_dim, kqv_dim * num_heads)
    # project the concatenated heads back to emb_dim
    self.w_out = nn.Linear(kqv_dim * num_heads, emb_dim)

  def forward(self, x):
    # x: (batch, seq_len, emb_dim)
    b, t, _ = x.shape
    e = self.kqv_dim
    h = self.num_heads

    x = self.dropout(x)

    # project and split into heads: (b, t, h, e)
    keys = self.w_k(x).view(b, t, h, e)
    values = self.w_v(x).view(b, t, h, e)
    queries = self.w_q(x).view(b, t, h, e)

    # move the head dimension forward: (b, h, t, e)
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)

    # scale both sides by e ** 0.25, so the dot product is scaled by sqrt(e)
    keys = keys / (e ** 0.25)
    queries = queries / (e ** 0.25)

    # attention weights: (b, h, t, t)
    dot = queries @ keys.transpose(2, 3)
    dot = F.softmax(dot, dim=-1)

    # weighted sum of values, then concatenate heads: (b, t, h * e)
    out = dot @ values
    out = out.transpose(1, 2).contiguous().view(b, t, h * e)
    out = self.w_out(out)
    if self.residual:
      out = out + x
    out = self.layer_norm(out)
    return out
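
For reference, a quick shape check with the 512-dim / 24-head / 64-per-head configuration mentioned above (the dropout value is arbitrary):

mhsa = MHSA(emb_dim=512, kqv_dim=64, num_heads=24, dropout=0.1)
x = torch.randn(8, 128, 512)  # (batch, seq_len, emb_dim)
print(mhsa(x).shape)          # torch.Size([8, 128, 512])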

Could someone point out where the problem is?