nn.TransformerEncoder input shape

Hi, I am building a sequence-to-sequence model using nn.TransformerEncoder and I am not sure the shapes of my inputs are correct. The nn.Transformer documentation states that the input of the model should be (sequence_length, batch_size, embedding_dim), but the nn.TransformerEncoder documentation gives no details about shapes. After looking at the PyTorch seq2seq-with-transformer example, it seemed that the expected input shape is indeed (sequence_length, batch_size, embedding_dim).

However, I tried both (sequence_length, batch_size, embedding_dim) and the more conventional (batch_size, sequence_length, embedding_dim), and both approaches seem to converge. So my question is: what is the proper way of using this module?
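To make the question concrete, here is a minimal sketch of the convention I think the module expects (the sizes are just placeholders I picked for illustration):

import torch
import torch.nn as nn

seq_len, batch_size, d_model = 10, 4, 32  # placeholder sizes
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4)
encoder = nn.TransformerEncoder(layer, num_layers=2)

# Documented layout: (sequence_length, batch_size, embedding_dim)
x = torch.randn(seq_len, batch_size, d_model)
out = encoder(x)
print(out.shape)  # torch.Size([10, 4, 32]) -- output shape matches the input

If I understand correctly, recent PyTorch releases (1.9+) also add a batch_first=True flag to nn.TransformerEncoderLayer that makes the (batch_size, sequence_length, embedding_dim) layout explicit, but I would still like to confirm what the default expects.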

Here’s my model:

import torch
import torch.nn as nn
from .positional_encodings import PeriodicPositionalEncoding, OriginalPositionalEncoding
from .decoders import LinearDecoder

class TsTransformer(nn.Module):

    def __init__(self, in_features, out_features, sequence_length, d_model, num_heads, 
                 num_encoder_layers, num_decoder_layers, dim_feedforward, flatten_encoded, 
                 transformer_dropout=.1, decoder_dropout=.2, encoding_type='periodic',
                 decoder_type='linear'):
        
        super().__init__()

        # Linear mapping to model dimension
        self.fc1 = nn.Linear(in_features=in_features, out_features=d_model)
        # Positional encoding
        if encoding_type == 'original':
            self.positional_encoding = OriginalPositionalEncoding(d_model=d_model,
                                                                  dropout=transformer_dropout)
        elif encoding_type == 'periodic':
            self.positional_encoding = PeriodicPositionalEncoding(sequence_length=sequence_length,
                                                                  dropout=transformer_dropout)
        else:
            raise ValueError('Unknown encoding type: {}'.format(encoding_type))
        # Encoder layers
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=transformer_dropout)
        # Transformer encoder
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, 
                                             num_layers=num_encoder_layers)
        # Decoder
        if decoder_type == 'linear':
            self.decoder = LinearDecoder(d_model=d_model,
                                         sequence_length=sequence_length,
                                         out_features=out_features, 
                                         flatten=flatten_encoded,
                                         num_layers=num_decoder_layers,
                                         dropout=decoder_dropout) 
        else:
            raise ValueError('Unknown decoder type: {}'.format(decoder_type))
        
    def forward(self, inputs):
        '''
        inputs: (batch_size, sequence_length, in_features)
        '''
        # Map to model dim
        x = self.fc1(inputs)
        # Compute encodings and add
        x = x.permute(0, 2, 1)  # (batch_size, d_model, sequence_length) for the positional encoding
        x = self.positional_encoding(x)
        # Transformer encode
        x = x.permute(2, 0, 1)  # (sequence_length, batch_size, d_model)
        x = self.encoder(x)
        x = x.permute(1, 0, 2)  # back to (batch_size, sequence_length, d_model)
        # Decode
        x = self.decoder(x)

        return x
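
In case it helps, this is roughly how I instantiate and call the model (the hyperparameters are placeholders, and PeriodicPositionalEncoding / LinearDecoder are my own modules from the imports above, so this is a sketch rather than a standalone script):

model = TsTransformer(in_features=8, out_features=1, sequence_length=50,
                      d_model=64, num_heads=4, num_encoder_layers=3,
                      num_decoder_layers=2, dim_feedforward=256,
                      flatten_encoded=False)
inputs = torch.randn(16, 50, 8)  # (batch_size, sequence_length, in_features)
outputs = model(inputs)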

Thanks for your help :slight_smile: