Multihead attention has an extra linear layer?

import torch
from torch import nn

mha = nn.MultiheadAttention(512, 4, bias=False)
[p.shape for p in mha.parameters()]


[torch.Size([1536, 512]), torch.Size([512, 512]), torch.Size([512])]
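The parameter names make the layout clearer than the bare shapes. A quick way to inspect them (note: on recent PyTorch releases, bias=False also removes the output projection's bias, so you may see only two tensors instead of three):

```python
import torch
from torch import nn

mha = nn.MultiheadAttention(512, 4, bias=False)

# The first tensor stacks the Q, K and V projections (3 * 512 = 1536 rows);
# the remaining tensors belong to the output projection.
for name, p in mha.named_parameters():
    print(name, tuple(p.shape))
```

This prints `in_proj_weight` for the [1536, 512] tensor and `out_proj.weight` (plus, on older versions, `out_proj.bias`) for the rest.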

So I see indices 1 and 2 as forming an “extra” linear layer (a weight and a bias).

Going by Attention Is All You Need, I wasn’t aware of an extra linear layer being part of multi-head attention.


Although the paper’s figure does include the blue Feed Forward box, I see that as a separate component.

So what’s going on in the PyTorch implementation? Is there an easy way to drop it?
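There doesn’t seem to be a constructor switch for removing the output projection, but one workaround sketch (my own, not an official API) is to freeze it as an identity map so it has no effect:

```python
import torch
from torch import nn

mha = nn.MultiheadAttention(512, 4, bias=False)

# Turn the output projection into a frozen identity map.
# This is a workaround, not a documented option.
with torch.no_grad():
    mha.out_proj.weight.copy_(torch.eye(512))
    if mha.out_proj.bias is not None:
        mha.out_proj.bias.zero_()
mha.out_proj.weight.requires_grad_(False)
```

After this, the module’s output is just the concatenated per-head attention results.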

Well, after thinking about it some more, I realised I forgot to consider $W^O$ from section 3.2.2 of the paper! I believe that’s it, right?
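For anyone who wants to verify this: the [1536, 512] tensor splits cleanly into $W^Q$, $W^K$, $W^V$, and out_proj is exactly $W^O$. A sketch that re-derives the module’s output by hand (assumes PyTorch >= 2.0 for F.scaled_dot_product_attention):

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
E, H, B, L = 512, 4, 2, 10
mha = nn.MultiheadAttention(E, H, bias=False, batch_first=True)

x = torch.randn(B, L, E)
ref, _ = mha(x, x, x)

# Split the stacked projection into W_q, W_k, W_v ([512, 512] each),
# run scaled dot-product attention per head, then apply out_proj (= W^O).
W_q, W_k, W_v = mha.in_proj_weight.chunk(3)
q = (x @ W_q.T).view(B, L, H, E // H).transpose(1, 2)
k = (x @ W_k.T).view(B, L, H, E // H).transpose(1, 2)
v = (x @ W_v.T).view(B, L, H, E // H).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v)
out = mha.out_proj(attn.transpose(1, 2).reshape(B, L, E))

print(torch.allclose(ref, out, atol=1e-5))
```

If the two agree, the “extra” layer really is just $W^O$ applied to the concatenated heads.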