Implementation of Multi-Head Attention

Hi. Is this a correct implementation of MHA, where I create three linear layers to compute q, k, and v? I think defining three smaller layers is more stable than one large layer.

import torch
import torch.nn as nn

class MHA(nn.Module):
    def __init__(self, input_dim, embed_dim, h, dropout=0, args=None):
        super(MHA, self).__init__()  # must name this class, not SelfAttention
        self.args = {} if args is None else args
        self.multihead_attn = nn.MultiheadAttention(embed_dim, h)
        self.dropout = nn.Dropout(dropout)
        # separate projections for query, key, and value
        self.q_proj = nn.Linear(input_dim, embed_dim)
        self.k_proj = nn.Linear(input_dim, embed_dim)
        self.v_proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        query = self.dropout(self.q_proj(x))
        key   = self.dropout(self.k_proj(x))
        value = self.dropout(self.v_proj(x))
        return self.multihead_attn(query, key, value, **self.args)

I have some other questions:

  1. Why doesn’t nn.MultiheadAttention take ‘x’ and produce q, k, v itself?
  2. If the number of heads is set to 1 in the MHA module, then self-attention is obtained, but in the documentation of nn.MultiheadAttention it is written that self-attention is computed when all three inputs take the value x.

What do you mean by “more stable” here?
The three are strictly equivalent mathematically: if you concatenate the three weight matrices and biases, send in the input x, and split the output, you get exactly the same result as if you keep them separate.
Computationally, it is often more efficient to run one large layer rather than three small ones.
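You can check the equivalence directly. Here is a minimal sketch (the name `fused` is mine, not from any library): build one large nn.Linear from the three sets of weights and biases, then split its output into q, k, and v.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_dim, embed_dim = 8, 4

# three separate projections, as in the question
q_proj = nn.Linear(input_dim, embed_dim)
k_proj = nn.Linear(input_dim, embed_dim)
v_proj = nn.Linear(input_dim, embed_dim)

# one fused projection built from the same parameters
fused = nn.Linear(input_dim, 3 * embed_dim)
with torch.no_grad():
    fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    fused.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

x = torch.randn(2, input_dim)
q, k, v = fused(x).split(embed_dim, dim=-1)  # split the large output back into three

# identical results up to float rounding
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)
```

The fused version launches one matrix multiplication instead of three, which is typically faster on a GPU.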

Best regards


If we assume each linear layer maps from 256 to 64, then the single large layer maps from 256 to 192 — would its convergence therefore take longer? I’m not sure about this. And wouldn’t the matrices formed on the GPU to perform the calculations be smaller with three separate layers?
What do you think about questions 1 and 2?

nn.MultiheadAttention does contain internal linear projections like the ones you use here, so you can feed x, x, x directly if you want.
Note that for “external” (cross-)attention in the decoder of encoder-decoder networks (e.g. in machine translation) you typically feed the query from the decoder and the key/value from the encoder.

“Self-attention” here is in contrast to external (cross-)attention, not in contrast to multi-head attention.
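As a sketch of the distinction: self-attention passes the same tensor as query, key, and value, while cross-attention takes the query from the decoder and the key/value from the encoder memory. (Shapes below follow the default batch_first=False layout of nn.MultiheadAttention: seq_len, batch, embed_dim.)

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)

# self-attention: q = k = v = x
x = torch.randn(10, 2, 16)           # (seq_len, batch, embed_dim)
self_out, _ = mha(x, x, x)           # output: (10, 2, 16)

# cross ("external") attention: query from decoder, key/value from encoder
enc = torch.randn(12, 2, 16)         # encoder memory, longer sequence
dec = torch.randn(10, 2, 16)         # decoder states
cross_out, _ = mha(dec, enc, enc)    # output follows the query length: (10, 2, 16)
```

Note that the output sequence length follows the query, not the key/value sequence.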

Best regards