nn.TransformerEncoderLayer mismatch on batch size dimension?

In the forward function of nn.TransformerEncoderLayer, the input goes through MultiheadAttention, followed by Dropout, then LayerNorm. According to the documentation, the input-output shape of MultiheadAttention is (S, N, E) → (T, N, E), where S is the source sequence length, T is the target sequence length, N is the batch size, and E is the embedding dimension. The input-output shape of LayerNorm is (N, *) → (N, *). Wouldn’t this cause a problem, since the batch size of the MultiheadAttention output is in the second dimension, while LayerNorm expects the batch size to be in the first dimension? Or am I missing something? Thanks!

From LayerNorm — PyTorch 2.1 documentation,

If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size.

So, TL;DR: nn.LayerNorm doesn’t care whether seq len and batch size are permuted, as long as the last dimension has the expected size.
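
A minimal sketch of that behaviour (the sizes here are arbitrary, just for illustration):

```python
import torch
import torch.nn as nn

S, N, E = 10, 4, 32          # seq_len, batch_size, embed_dim (arbitrary)
ln = nn.LayerNorm(E)         # normalized_shape=E -> normalizes over the last dim only

x = torch.randn(S, N, E)     # (seq_len, batch_size, embed_dim)
out_sne = ln(x)
out_nse = ln(x.transpose(0, 1))   # (batch_size, seq_len, embed_dim)

# Each E-vector is normalized independently, so the layout of the first
# two dims doesn't matter: transposing back gives the same result.
print(torch.allclose(out_sne, out_nse.transpose(0, 1)))  # True
```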

In the case of nn.TransformerEncoderLayer (below I do not distinguish between embed_dim and d_model for the internal operations):

  • if batch_first=True, the input to the model has shape (batch_size, seq_len, dim): here the N of nn.LayerNorm corresponds to batch_size, * to (seq_len, dim), and normalized_shape to dim
  • otherwise, the input to the model has shape (seq_len, batch_size, dim): here the N of nn.LayerNorm corresponds to seq_len, * to (batch_size, dim), and normalized_shape to dim

But in all cases the normalization is done along the last dimension (dim): per the documentation of nn.LayerNorm, when a single integer is used for normalized_shape, it is treated as a singleton list, and the module normalizes over the last dimension, which is expected to be of that specific size.
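
A quick way to check this end to end (a sketch; d_model, nhead, and the sizes are arbitrary): build two nn.TransformerEncoderLayer instances that share the same weights and differ only in batch_first, and compare their outputs up to a transpose. Dropout is set to 0 and eval() is used so the comparison is deterministic.

```python
import torch
import torch.nn as nn

N, S, E = 4, 10, 32   # batch_size, seq_len, d_model (arbitrary)

layer_sn = nn.TransformerEncoderLayer(d_model=E, nhead=4, dropout=0.0, batch_first=False)
layer_ns = nn.TransformerEncoderLayer(d_model=E, nhead=4, dropout=0.0, batch_first=True)
layer_ns.load_state_dict(layer_sn.state_dict())   # batch_first doesn't change parameter names
layer_sn.eval()
layer_ns.eval()

x = torch.randn(N, S, E)                # (batch_size, seq_len, d_model)
out_ns = layer_ns(x)                    # batch_first=True path
out_sn = layer_sn(x.transpose(0, 1))    # (seq_len, batch_size, d_model) path

# Same computation, only the layout differs
print(torch.allclose(out_ns, out_sn.transpose(0, 1), atol=1e-5))  # True
```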

There is no need to talk about a target length in the case of the encoder: there is only one input, of length seq_len / source_seq_len. It is with nn.TransformerDecoderLayer that both T and S (target length and source length) come into play. There, the LayerNorms are applied to the target stream, of shape (target_seq_len, batch_size, dim) (or (batch_size, target_seq_len, dim) for batch_first=True), while on the encoder side they are applied to (source_seq_len, batch_size, dim) (or (batch_size, source_seq_len, dim) for batch_first=True); in every case the normalization follows the last dimension, no matter what value batch_first has.
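
To make the T/S distinction concrete, here is a cross-attention shape check with nn.MultiheadAttention (the sizes are arbitrary; batch_first is left at its default of False):

```python
import torch
import torch.nn as nn

T, S, N, E = 6, 10, 4, 32    # target len, source len, batch_size, embed_dim (arbitrary)
mha = nn.MultiheadAttention(embed_dim=E, num_heads=4)   # batch_first=False by default

query = torch.randn(T, N, E)   # comes from the target side: (T, N, E)
key   = torch.randn(S, N, E)   # comes from the source side: (S, N, E)
value = torch.randn(S, N, E)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)     # torch.Size([6, 4, 32])  -> (T, N, E)
print(attn_weights.shape)    # torch.Size([4, 6, 10])  -> (N, T, S), averaged over heads
```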

Thank you both, that clarifies my confusion. For nn.LayerNorm, it doesn’t matter whether the batch size is in the first or second index for a 3D tensor, because it acts on the last dimension.

A related follow-up question to clarify my understanding: for nn.MultiheadAttention, unlike nn.LayerNorm, it does matter whether the batch size is in the first or second dimension, right? It doesn’t just act on the last dimension?

In fact, all the computations inside nn.MultiheadAttention are done with the shape (seq_len, batch_size, dim), i.e. batch_first=False.
Below, seq_len corresponds to target_seq_len in the case of the query, and to source_seq_len in the case of the key and value.

As you can see in the nn.MultiheadAttention source code, when batch_first=True, i.e. the inputs (query, key, value) have shape (batch_size, seq_len, dim), they are first transposed to (seq_len, batch_size, dim) before the computation, and then the result (attn_output), of shape (seq_len, batch_size, dim), is transposed back to (batch_size, seq_len, dim) before being returned.
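
The same weight-sharing trick as above makes this visible (a sketch; sizes are arbitrary, and nn.MultiheadAttention’s dropout defaults to 0, so the comparison is deterministic):

```python
import torch
import torch.nn as nn

N, L, E = 4, 10, 32    # batch_size, seq_len, embed_dim (arbitrary)

mha_sn = nn.MultiheadAttention(embed_dim=E, num_heads=4, batch_first=False)
mha_ns = nn.MultiheadAttention(embed_dim=E, num_heads=4, batch_first=True)
mha_ns.load_state_dict(mha_sn.state_dict())   # same weights, different expected layout

x = torch.randn(N, L, E)                      # (batch_size, seq_len, embed_dim)
x_t = x.transpose(0, 1)                       # (seq_len, batch_size, embed_dim)

out_ns, _ = mha_ns(x, x, x)                   # batch_first=True: transposes internally
out_sn, _ = mha_sn(x_t, x_t, x_t)             # batch_first=False: used as-is

# batch_first=True is just transpose -> compute -> transpose back
print(torch.allclose(out_ns, out_sn.transpose(0, 1), atol=1e-6))  # True
```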

That makes sense, thank you!