Hi, I am building a sequence-to-sequence model using nn.TransformerEncoder and I am not sure the shapes of my inputs are correct. The nn.Transformer documentation states that the input of the model should be (sequence_length, batch_size, embedding_dim), but there are no details about the expected shapes in the nn.TransformerEncoder documentation. After looking at the PyTorch seq2seq-with-transformer example, it seems the expected input shape is indeed (sequence_length, batch_size, embedding_dim).
However, I tried both (sequence_length, batch_size, embedding_dim) and the more conventional (batch_size, sequence_length, embedding_dim), and both approaches seem to converge. So my question is: what is the proper way of using this module?
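To illustrate what I mean, here is a minimal standalone sketch of the two conventions I tried (the d_model=32, nhead=4 values are arbitrary, just for the example). Both calls run without error, which is part of what confuses me:

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=32, nhead=4)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

seq_len, batch_size, d_model = 10, 8, 32

# Convention from the nn.Transformer docs: (sequence_length, batch_size, embedding_dim)
x_seq_first = torch.randn(seq_len, batch_size, d_model)
out1 = encoder(x_seq_first)   # (10, 8, 32)

# The more conventional layout: (batch_size, sequence_length, embedding_dim)
x_batch_first = torch.randn(batch_size, seq_len, d_model)
out2 = encoder(x_batch_first) # (8, 10, 32) -- also runs without complaint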
Here’s my model:
import torch
import torch.nn as nn

from .positional_encodings import PeriodicPositionalEncoding, OriginalPositionalEncoding
from .decoders import LinearDecoder


class TsTransformer(nn.Module):
    def __init__(self, in_features, out_features, sequence_length, d_model, num_heads,
                 num_encoder_layers, num_decoder_layers, dim_feedforward, flatten_encoded,
                 transformer_dropout=.1, decoder_dropout=.2, encoding_type='periodic',
                 decoder_type='linear'):
        super().__init__()
        # Linear mapping from input features to model dimension
        self.fc1 = nn.Linear(in_features=in_features, out_features=d_model)
        # Positional encoding
        if encoding_type == 'original':
            self.positional_encoding = OriginalPositionalEncoding(d_model=d_model,
                                                                  dropout=transformer_dropout)
        elif encoding_type == 'periodic':
            self.positional_encoding = PeriodicPositionalEncoding(sequence_length=sequence_length,
                                                                  dropout=transformer_dropout)
        else:
            raise ValueError('Unknown encoding type: {}'.format(encoding_type))
        # Encoder layers
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=transformer_dropout)
        # Transformer encoder
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                             num_layers=num_encoder_layers)
        # Decoder
        if decoder_type == 'linear':
            self.decoder = LinearDecoder(d_model=d_model,
                                         sequence_length=sequence_length,
                                         out_features=out_features,
                                         flatten=flatten_encoded,
                                         num_layers=num_decoder_layers,
                                         dropout=decoder_dropout)
        else:
            raise ValueError('Unknown decoder type: {}'.format(decoder_type))

    def forward(self, inputs):
        '''
        inputs: (batch_size, sequence_length, in_features)
        '''
        # Map input features to model dimension: (batch_size, sequence_length, d_model)
        x = self.fc1(inputs)
        # Compute positional encodings and add them: (batch_size, d_model, sequence_length)
        x = x.permute(0, 2, 1)
        x = self.positional_encoding(x)
        # Transformer encode: (sequence_length, batch_size, d_model)
        x = x.permute(2, 0, 1)
        x = self.encoder(x)
        # Back to (batch_size, sequence_length, d_model)
        x = x.permute(1, 0, 2)
        # Decode
        x = self.decoder(x)
        return x
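And for reference, a stripped-down trace of the shape flow in my forward pass, with the custom positional encoding and decoder replaced by placeholders (nn.Identity and a plain nn.Linear) just so it runs on its own:

import torch
import torch.nn as nn

batch_size, seq_len, in_features, d_model = 8, 10, 6, 32

fc1 = nn.Linear(in_features, d_model)
positional_encoding = nn.Identity()  # placeholder for my custom encoding
encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=d_model, nhead=4),
                                num_layers=2)
decoder = nn.Linear(d_model, 1)      # placeholder for LinearDecoder

inputs = torch.randn(batch_size, seq_len, in_features)
x = fc1(inputs)                      # (batch_size, seq_len, d_model)
x = x.permute(0, 2, 1)               # (batch_size, d_model, seq_len)
x = positional_encoding(x)
x = x.permute(2, 0, 1)               # (seq_len, batch_size, d_model)
x = encoder(x)
x = x.permute(1, 0, 2)               # (batch_size, seq_len, d_model)
x = decoder(x)
print(x.shape)                       # torch.Size([8, 10, 1])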
Thanks for your help!