Is nn.MultiheadAttention in PyTorch just a linear transformation layer, while nn.TransformerEncoderLayer is the combination of nn.MultiheadAttention and a feed-forward layer?
In addition, here is an example of the printed structure of nn.TransformerEncoderLayer (the out_proj line sits inside the self_attn submodule):

```
TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (linear1): Linear(in_features=128, out_features=512, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=512, out_features=128, bias=True)
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)
```
My question is, does the nonlinearity only come from LayerNorm and dropout layers? I did not see any activation function in this block.
Isn’t there a softmax function within the MultiheadAttention nn.Module? That’s a non-linear function.
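To see that the softmax alone already breaks linearity, here is a quick pure-Python check (independent of PyTorch): a linear map f would satisfy f(2x) = 2·f(x), and softmax does not.

```python
import math

def softmax(xs):
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

x = [1.0, 2.0, 3.0]
print(softmax(x))                    # ≈ [0.090, 0.245, 0.665]
print(softmax([2 * v for v in x]))   # ≈ [0.016, 0.117, 0.867], not 2 × the above
```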
Hi, thanks for your reply, but I could not find any activation functions defined on the MultiheadAttention — PyTorch 1.13 documentation page.
After a brief look through the source code, nn.MultiheadAttention calls multi_head_attention_forward, which contains a softmax call, as seen in the source here. So it does seem to have a non-linear function of some kind. @ptrblck (apologies for the tag) will be able to confirm any more specific details, but the source does show a non-linear function being used.
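As a side note on the original question: the feed-forward block of nn.TransformerEncoderLayer also applies an activation (ReLU by default) between linear1 and linear2. It is stored as a plain callable rather than an nn.Module, which is why (at least in the versions I have checked) it does not appear in the printed module tree. A minimal sketch, assuming default arguments:

```python
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=128, nhead=8, dim_feedforward=512)

# The activation sits between linear1 and linear2 in forward(), but it is a
# plain function (torch.nn.functional.relu by default), not a registered
# submodule, so print(layer) does not list it.
print(layer.activation)
```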
Thanks again. I checked the source too and found that there is a linear layer at the end of the multi_head_attention_forward function that maps the derived attention output from embed_dim to embed_dim:
pytorch/functional.py at 0a274c4b6c916363ce3e3f75b315ac66156f8ce6 · pytorch/pytorch · GitHub
And in the MultiheadAttention module, which calls that function, self.out_proj is defined as follows:

```
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
```
So I guess there is an extra linear transformation applied to the attention output before PyTorch returns it, beyond what I originally expected.
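To confirm this, a quick check of that extra projection (using a hypothetical 128-dim, 8-head setup matching the repr above):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)

# out_proj is the Linear(embed_dim -> embed_dim) applied after the
# per-head attention outputs are concatenated.
print(mha.out_proj)

x = torch.randn(2, 10, 128)       # (batch, seq, embed_dim)
out, attn_weights = mha(x, x, x)  # self-attention: query = key = value
print(out.shape)                  # same embed_dim in and out: (2, 10, 128)
```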
@ptrblck @AlphaBetaGamma96 Could you have a look? Thanks!