Hey there! This is my first post, so please be nice
I work a lot with time series (forecasting and classification), and until now my go-to model has always been the LSTM for ease of use. I recently started dabbling with Transformers and tried to employ them for some of the use cases where I was successful with LSTMs, but I am struggling to get them to perform anywhere near as well.
Example LSTM:
One recent example is where I trained an LSTM to learn the parameters of a normal distribution conditioned on a time series. The corresponding model looked like this:
import torch
from torch import Tensor, nn
from torch.distributions import Distribution, MultivariateNormal
from torch.nn.utils.rnn import PackedSequence


class ProbabilisticLSTM(nn.Module):
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            pos_weight: Tensor = None,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True)
        self.mean_predictor = nn.Linear(
            in_features=hidden_size,
            out_features=input_size
        )
        self.variances = nn.Linear(
            in_features=hidden_size,
            out_features=input_size
        )
        # one output per strictly lower-triangular entry, i.e. input_size * (input_size - 1) / 2
        self.covariances = nn.Linear(
            in_features=hidden_size,
            out_features=torch.arange(input_size).sum().item()
        )
        self.softplus = nn.Softplus()
        self.mse = nn.MSELoss()

    @staticmethod
    def _create_scale_tril(variances, covariances) -> Tensor:
        # positive variances on the diagonal, unconstrained covariances in the
        # strictly lower triangle
        covariance_matrix = torch.diag_embed(variances)
        input_size = variances.shape[1]
        rows, cols = torch.tril_indices(input_size, input_size, offset=-1)
        covariance_matrix[:, rows, cols] = covariances
        return covariance_matrix

    def forward(self, x: PackedSequence | Tensor, last_hidden: tuple[Tensor, Tensor] = None) -> tuple[
            Distribution, tuple[Tensor, Tensor]]:
        _, (hn, cn) = self.lstm(x, last_hidden)
        x = hn.squeeze()
        means = self.mean_predictor(x)
        variances = self.softplus(self.variances(x))
        covariances = self.covariances(x)
        scale_tril = self._create_scale_tril(variances, covariances)
        return MultivariateNormal(loc=means, scale_tril=scale_tril), (hn, cn)

    def compute_forecast_loss(self, x: PackedSequence, y: Tensor):
        dist, _ = self.forward(x)
        return {
            "NLL": -dist.log_prob(y).mean(),
            "MSE": self.mse(dist.mean, y)
        }
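For reference, this is roughly how I feed the model; the tensors and hyperparameters below are made up purely for illustration (the real pipeline packs my actual series and runs a normal training loop with an optimizer):

# Illustrative only: random variable-length series, packed for the LSTM.
import torch
from torch.nn.utils.rnn import pack_sequence

input_size, hidden_size = 4, 32
model = ProbabilisticLSTM(input_size=input_size, hidden_size=hidden_size)

# three series of different lengths, each step has `input_size` features
series = [torch.randn(length, input_size) for length in (7, 5, 3)]
x = pack_sequence(series, enforce_sorted=False)
y = torch.randn(len(series), input_size)  # one target vector per series

losses = model.compute_forecast_loss(x, y)
losses["NLL"].backward()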
Transfer to Transformer:
The above LSTM, despite being comparatively simple, yielded useful results. I tried to design an analogous architecture using a transformer, like so:
import math

import torch
from torch import Tensor, nn
from torch.distributions import Distribution, MultivariateNormal
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class TransformerForecaster(nn.Module):
    _PADDING = 999999999

    def __init__(self,
                 n_features: int,
                 latent_dim: int,
                 n_heads: int,
                 dim_feedforward: int,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.n_heads = n_heads
        self.latent = nn.Linear(
            in_features=n_features,
            out_features=latent_dim
        )
        self.positional_encoding = PositionalEncoding(
            d_model=latent_dim,
            dropout=0,
        )
        self.transformer = nn.Transformer(
            d_model=latent_dim,
            nhead=n_heads,
            num_encoder_layers=1, num_decoder_layers=1,
            dim_feedforward=dim_feedforward,
            batch_first=True
        )
        # learned decoder queries: one embedding per mean/variance and one per
        # strictly lower-triangular covariance entry
        self.means = nn.Embedding(
            num_embeddings=n_features,
            embedding_dim=latent_dim,
        )
        self.variances = nn.Embedding(
            num_embeddings=n_features,
            embedding_dim=latent_dim,
        )
        self.covariances = nn.Embedding(
            num_embeddings=torch.arange(n_features).sum().item(),
            embedding_dim=latent_dim,
        )
        self.means_out = nn.Linear(in_features=latent_dim, out_features=1)
        self.variances_out = nn.Linear(in_features=latent_dim, out_features=1)
        self.covariances_out = nn.Linear(in_features=latent_dim, out_features=1)
        self.softplus = nn.Softplus()
        self.mse = nn.MSELoss()

    def pad_and_produce_masks(self, x: PackedSequence) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        x_padded, _ = pad_packed_sequence(x, batch_first=True, padding_value=self._PADDING)
        tensor_mask = x_padded == self._PADDING
        # tensor_mask = torch.isnan(x_padded)
        # a time step counts as padding if all of its features equal the padding value
        key_mask = torch.all(tensor_mask, dim=-1)
        attn_mask = key_mask.unsqueeze(dim=1).repeat((self.n_heads, key_mask.size(1), 1))
        return x_padded, ~key_mask.unsqueeze(dim=1), key_mask, attn_mask

    @staticmethod
    def _create_scale_tril(variances, covariances) -> Tensor:
        covariance_matrix = torch.diag_embed(variances)
        input_size = variances.shape[1]
        rows, cols = torch.tril_indices(input_size, input_size, offset=-1)
        covariance_matrix[:, rows, cols] = covariances
        return covariance_matrix

    def forward(self, x: PackedSequence) -> Distribution:
        x, tensor_mask, key_mask, attn_mask = self.pad_and_produce_masks(x)
        x = self.latent(x)
        x = self.positional_encoding(x)
        means = self.transformer(
            src=x,
            tgt=self.means.weight.unsqueeze(dim=0).repeat((x.size(0), 1, 1)),
            src_key_padding_mask=key_mask,
            src_mask=attn_mask,
            memory_key_padding_mask=key_mask,
            memory_mask=attn_mask.all(dim=1, keepdim=True).repeat((1, self.means.num_embeddings, 1)),
        )
        variances = self.transformer(
            src=x,
            tgt=self.variances.weight.unsqueeze(dim=0).repeat((x.size(0), 1, 1)),
            src_key_padding_mask=key_mask,
            src_mask=attn_mask,
            memory_key_padding_mask=key_mask,
            memory_mask=attn_mask.all(dim=1, keepdim=True).repeat((1, self.variances.num_embeddings, 1)),
        )
        covariances = self.transformer(
            src=x,
            tgt=self.covariances.weight.unsqueeze(dim=0).repeat((x.size(0), 1, 1)),
            src_key_padding_mask=key_mask,
            src_mask=attn_mask,
            memory_key_padding_mask=key_mask,
            memory_mask=attn_mask.all(dim=1, keepdim=True).repeat((1, self.covariances.num_embeddings, 1)),
        )
        means = self.means_out(means).squeeze()
        variances = self.variances_out(variances).squeeze()
        variances = self.softplus(variances)
        covariances = self.covariances_out(covariances).squeeze()
        scale_tril = self._create_scale_tril(variances, covariances)
        return MultivariateNormal(loc=means, scale_tril=scale_tril)

    def compute_loss(self, x: PackedSequence, y: Tensor) -> dict:
        dist = self.forward(x)
        return {
            "NLL": -dist.log_prob(y).mean(),
            "MSE": self.mse(dist.mean, y)
        }
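And this is roughly how I call it; again, the data and hyperparameters are made up just to illustrate the expected shapes:

# Illustrative only: the same kind of fake data, fed through the transformer.
import torch
from torch.nn.utils.rnn import pack_sequence

n_features, latent_dim = 4, 32
model = TransformerForecaster(n_features=n_features, latent_dim=latent_dim,
                              n_heads=4, dim_feedforward=64)

series = [torch.randn(length, n_features) for length in (7, 5, 3)]
x = pack_sequence(series, enforce_sorted=False)
y = torch.randn(len(series), n_features)

dist = model(x)
print(dist.mean.shape)        # (batch, n_features)
print(dist.scale_tril.shape)  # (batch, n_features, n_features)
losses = model.compute_loss(x, y)
losses["NLL"].backward()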
The Problem:
I expect the transformer's performance to at least be comparable to that of the LSTM. However, when training the transformer it performs much worse than the LSTM and other, much simpler predictors. Even on much simpler forecasting or classification tasks, I have never gotten a transformer to perform comparably to an LSTM. The transformer's performance is also not so outlandish as to suggest that something is badly broken; it is always in the same ballpark, just significantly worse.
Potential problems I have considered:

1. Maybe the positional encoding is off? I am using the common positional encoding from Attention is All You Need that you find in every other tutorial. However, my data set consists of approx. 1000 different multivariate time series with anywhere between 2 and 30 steps each. So I was thinking that the encoding may be too granular (as in not providing enough structure), since it seems to be designed for much longer sequences.
2. Maybe the masking is off? The time series all have varying lengths. For LSTMs there is the wrapper class PackedSequence, which conveniently makes this a non-issue. As far as I know, nothing analogous exists for transformers (please correct me if I am wrong!). Hence, for batch learning I need some form of masking logic, which, admittedly, I implemented very nastily; I would be happy about suggestions for improvement, and I have put a cleaner sketch of what I mean after this list. I am wondering whether I misunderstood how the masking is supposed to be used or employed it in the wrong manner.
3. Maybe my architecture is conceptually flawed? As can be seen above, I use embeddings for the means, variances and covariances of the distribution. In the forward pass I query the embedding spaces, using the input time series as source and the embedding weights as targets, and project from latent space to single variables with linear layers. Is this conceptually sound? (I also tried using a single embedding and linear layers for extracting means, variances and covariances, but to no avail.)
4. Do I need more data? Is the problem that transformers are too "data-hungry" and would only surpass an LSTM on larger data sets?
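Regarding point 2, this is the kind of cleaner masking I have in mind: pad with zeros and build the key padding mask directly from the true sequence lengths instead of comparing against a sentinel padding value. The pad_and_mask helper below is just a rough sketch for illustration, not what I actually trained with:

# Rough sketch: derive the key padding mask from the sequence lengths.
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_and_mask(series: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    lengths = torch.tensor([s.size(0) for s in series])
    x_padded = pad_sequence(series, batch_first=True)  # (batch, max_len, n_features)
    max_len = x_padded.size(1)
    # True wherever a time step lies beyond the series' true length, i.e. is padding
    key_padding_mask = torch.arange(max_len)[None, :] >= lengths[:, None]
    return x_padded, key_padding_mask


series = [torch.randn(length, 4) for length in (7, 5, 3)]
x_padded, key_padding_mask = pad_and_mask(series)
print(x_padded.shape, key_padding_mask.shape)  # (3, 7, 4) (3, 7)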
How to go on?
I am at my wit's end. I have tried changing the architecture, learning-rate warmup and scheduling, many parameters, few parameters, lots of regularization, little regularization; nothing seems to do the trick. It is really hard to find resources online discussing similar problems / prediction tasks, so I would appreciate any help! Thanks for reading through this post if you have made it this far!