In the forward function of nn.TransformerEncoderLayer, the input goes through MultiheadAttention, followed by Dropout, then LayerNorm. According to the documentation, the input-output shape of MultiheadAttention is (S, N, E) → (T, N, E), where S is the source sequence length, T is the target sequence length, N is the batch size, and E is the embedding dimension. The input-output shape of LayerNorm is (N, *) → (N, *). Wouldn’t this cause a problem, since the batch size of the MultiheadAttention output is the second dimension, while LayerNorm expects the batch size to be in the first dimension? Or am I missing something? Thanks!
If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size.
So, TLDR it doesn’t care if seq len and batch size are permuted as long as the last dim is correct.
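To illustrate the point, here is a minimal sketch (the dimension sizes are arbitrary, chosen just for the demo): nn.LayerNorm(d) normalizes each length-d vector along the last dimension independently, so swapping the sequence and batch dimensions and swapping them back gives the same result.

```python
import torch
import torch.nn as nn

d = 8
ln = nn.LayerNorm(d)  # normalizes over the last dimension only

x = torch.randn(5, 3, d)  # (seq_len, batch_size, dim)

out_a = ln(x)                                  # batch in dim 1
out_b = ln(x.transpose(0, 1)).transpose(0, 1)  # batch in dim 0, then back

# The permutation of the leading dims does not change the result.
assert torch.allclose(out_a, out_b, atol=1e-6)
```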
In the case of nn.TransformerEncoderLayer (below I do not distinguish between dim and d_model for the internal operations):
- with batch_first=True, the input of the model will be of shape (batch_size, seq_len, dim);
- otherwise, the input of the model will be of shape (seq_len, batch_size, dim).
But in all cases, the normalization is done along the last dimension (dim), because according to the documentation of nn.LayerNorm, when a single integer is used for normalized_shape, it is treated as a singleton list, and the module will normalize over the last dimension, which is expected to be of that specific size.
There is no need to talk about a target length in the case of the encoder: we have only one input, of length source_seq_len. It is only in the case of nn.TransformerDecoderLayer that we can talk about both T (target length) and S (source length). There, the nn.LayerNorm calls are applied to tensors of shape (target_seq_len, batch_size, dim) (or (batch_size, target_seq_len, dim) for batch_first=True), while the memory coming from the encoder has shape (source_seq_len, batch_size, dim) (or (batch_size, source_seq_len, dim) for batch_first=True) — and the normalization always follows the last dimension, no matter what the other dimensions are.
- TransformerEncoderLayer : pytorch/transformer.py at master · pytorch/pytorch · GitHub
- TransformerDecoderLayer : pytorch/transformer.py at master · pytorch/pytorch · GitHub
Thank you both, that clarifies my confusion. For nn.LayerNorm, it doesn’t matter whether the batch size is in the first or second index for a 3D tensor, because it acts on the last dimension.
A related follow-up question to clarify my understanding of nn.MultiheadAttention: for nn.MultiheadAttention, unlike nn.LayerNorm, it does matter whether the batch size is in the first or second index, right? It doesn’t just act on the last dimension?
In fact, all the computations inside nn.MultiheadAttention are done with tensors of shape (seq_len, batch_size, dim), where seq_len corresponds to target_seq_len in the case of the query, and to source_seq_len in the case of the key and value (the output follows the query length).
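A small cross-attention sketch to make the shape roles concrete (embed_dim, num_heads, and the lengths are arbitrary demo values): with the default batch_first=False, the query carries the target length, the key/value carry the source length, and the output follows the query.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)  # batch_first=False

query = torch.randn(7, 2, 16)   # (target_seq_len, batch_size, dim)
key   = torch.randn(11, 2, 16)  # (source_seq_len, batch_size, dim)
value = torch.randn(11, 2, 16)  # same length as the key

attn_out, attn_weights = mha(query, key, value)

assert attn_out.shape == (7, 2, 16)      # output follows the query (target) length
assert attn_weights.shape == (2, 7, 11)  # (batch, target, source), averaged over heads
```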
As you can see here, when batch_first=True, i.e. the inputs (query, key, value) are of shape (batch_size, seq_len, dim), they are first transposed to (seq_len, batch_size, dim) before the computations, and then the result (attn_output), of shape (seq_len, batch_size, dim), is transposed back to (batch_size, seq_len, dim) before being returned.
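This transpose-in, transpose-out behavior can be checked directly: a sketch assuming two modules that share the same weights (copied via load_state_dict; batch_first does not change the parameters, only the expected layout), and arbitrary demo sizes. The default dropout of 0.0 keeps the outputs deterministic.

```python
import torch
import torch.nn as nn

mha_sf = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=False)
mha_bf = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
mha_bf.load_state_dict(mha_sf.state_dict())  # identical weights in both modules

x_bf = torch.randn(2, 5, 16)   # (batch_size, seq_len, dim)
x_sf = x_bf.transpose(0, 1)    # (seq_len, batch_size, dim)

out_bf, _ = mha_bf(x_bf, x_bf, x_bf)  # self-attention, batch-first layout
out_sf, _ = mha_sf(x_sf, x_sf, x_sf)  # self-attention, seq-first layout

# batch_first=True is just a transpose before and after the same computation.
assert torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6)
```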
That makes sense, thank you!