Can't get Transformer to exceed LSTM - Help!

Hey there! This is my first post, so please be nice :slight_smile:

I work a lot with time series (forecasting and classification), and until now my go-to model has always been the LSTM, mostly for its ease of use. I recently started dabbling with Transformers and tried to employ them for some of the use cases where I was successful with LSTMs, but I am struggling to get them to perform anywhere near as well.

Example LSTM:

One recent example is where I trained an LSTM to learn the parameters of a normal distribution conditioned on a time series. The corresponding model looked like so:

import torch
from torch import nn, Tensor
from torch.distributions import Distribution, MultivariateNormal
from torch.nn.utils.rnn import PackedSequence


class ProbabilisticLSTM(nn.Module):
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
    ):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True)

        self.mean_predictor = nn.Linear(
            in_features=hidden_size,
            out_features=input_size
        )

        self.variances = nn.Linear(
            in_features=hidden_size,
            out_features=input_size
        )

        # One output per strictly lower-triangular entry of the scale matrix,
        # i.e. input_size * (input_size - 1) / 2 entries.
        self.covariances = nn.Linear(
            in_features=hidden_size,
            out_features=input_size * (input_size - 1) // 2
        )

        self.softplus = nn.Softplus()
        self.mse = nn.MSELoss()

    @staticmethod
    def _create_scale_tril(variances, covariances) -> Tensor:
        # Lower-triangular scale matrix: the (positive) variances on the diagonal,
        # the unconstrained covariance outputs on the strictly lower triangle.
        scale_tril = torch.diag_embed(variances)

        input_size = variances.shape[-1]
        rows, cols = torch.tril_indices(input_size, input_size, offset=-1)

        scale_tril[:, rows, cols] = covariances
        return scale_tril

    def forward(
            self,
            x: PackedSequence | Tensor,
            last_hidden: tuple[Tensor, Tensor] | None = None,
    ) -> tuple[Distribution, tuple[Tensor, Tensor]]:
        _, (hn, cn) = self.lstm(x, last_hidden)
        x = hn.squeeze(dim=0)  # drop the num_layers dimension -> (batch, hidden_size)

        means = self.mean_predictor(x)
        variances = self.softplus(self.variances(x))
        covariances = self.covariances(x)

        scale_tril = self._create_scale_tril(variances, covariances)

        return MultivariateNormal(loc=means, scale_tril=scale_tril), (hn, cn)

    def compute_forecast_loss(self, x: PackedSequence, y: Tensor):
        dist, _ = self.forward(x)
        return {
            "NLL": -dist.log_prob(y).mean(),
            "MSE": self.mse(dist.mean, y)
        }
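
For context, training is nothing special; a minimal sketch of how the model is called (shapes, lengths and hyperparameters here are placeholders, not my real setup):

# Minimal usage sketch with placeholder data.
from torch.nn.utils.rnn import pack_sequence

model = ProbabilisticLSTM(input_size=4, hidden_size=64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Three variable-length series with 4 features each; targets are the step to forecast.
sequences = [torch.randn(length, 4) for length in (12, 7, 25)]
targets = torch.randn(3, 4)

packed = pack_sequence(sequences, enforce_sorted=False)
losses = model.compute_forecast_loss(packed, targets)
losses["NLL"].backward()
optimizer.step()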

Transfer to Transformer:

The above LSTM, despite being comparatively simple, yielded useful results. I tried to design an analogous architecture using a transformer, like so:

import math

import torch
from torch import nn, Tensor
from torch.distributions import Distribution, MultivariateNormal
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class TransformerForecaster(nn.Module):
    _PADDING = -999999999

    def __init__(self,
                 n_features: int,
                 latent_dim: int,
                 n_heads: int,
                 dim_feedforward: int,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.n_heads = n_heads

        self.latent = nn.Linear(
            in_features=n_features,
            out_features=latent_dim
        )
        self.positional_encoding = PositionalEncoding(
            d_model=latent_dim,
            dropout=0,
        )
        self.transformer = nn.Transformer(
            d_model=latent_dim,
            nhead=n_heads,
            num_encoder_layers=1, num_decoder_layers=1,
            dim_feedforward=dim_feedforward,
            batch_first=True
        )
        self.means = nn.Embedding(
            num_embeddings=n_features,
            embedding_dim=latent_dim,
        )
        self.variances = nn.Embedding(
            num_embeddings=n_features,
            embedding_dim=latent_dim,
        )
        # One embedding per strictly lower-triangular entry: n_features * (n_features - 1) / 2.
        self.covariances = nn.Embedding(
            num_embeddings=n_features * (n_features - 1) // 2,
            embedding_dim=latent_dim,
        )
        self.means_out = nn.Linear(in_features=latent_dim, out_features=1)
        self.variances_out = nn.Linear(in_features=latent_dim, out_features=1)
        self.covariances_out = nn.Linear(in_features=latent_dim, out_features=1)

        self.softplus = nn.Softplus()
        self.mse = nn.MSELoss()

    def pad_and_produce_masks(self, x: PackedSequence) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        # Pad the variable-length batch and build the masks nn.Transformer expects.
        x_padded, _ = pad_packed_sequence(x, batch_first=True, padding_value=self._PADDING)
        tensor_mask = x_padded == self._PADDING                   # (batch, seq, n_features)
        key_mask = torch.all(tensor_mask, dim=-1)                 # (batch, seq), True where padded
        # Per-head attention mask of shape (batch * n_heads, seq, seq).
        attn_mask = key_mask.unsqueeze(dim=1).repeat((self.n_heads, key_mask.size(-1), 1))
        return x_padded, ~key_mask.unsqueeze(dim=-1), key_mask, attn_mask

    @staticmethod
    def _create_scale_tril(variances, covariances) -> Tensor:
        # Lower-triangular scale matrix: the (positive) variances on the diagonal,
        # the unconstrained covariance outputs on the strictly lower triangle.
        scale_tril = torch.diag_embed(variances)

        input_size = variances.shape[-1]
        rows, cols = torch.tril_indices(input_size, input_size, offset=-1)

        scale_tril[:, rows, cols] = covariances
        return scale_tril

    def forward(self, x: PackedSequence) -> Distribution:
        x, tensor_mask, key_mask, attn_mask = self.pad_and_produce_masks(x)

        x = self.latent(x)
        x = self.positional_encoding(x)

        means = self.transformer(
            src=x,
            tgt=self.means.weight.unsqueeze(dim=0).repeat((x.size(0), 1, 1)),
            src_key_padding_mask=key_mask,
            src_mask=attn_mask,
            memory_key_padding_mask=key_mask,
            memory_mask=attn_mask.all(dim=1, keepdim=True).repeat((1, self.means.num_embeddings, 1)),
        )
        variances = self.transformer(
            src=x,
            tgt=self.variances.weight.unsqueeze(dim=0).repeat((x.size(0), 1, 1)),
            src_key_padding_mask=key_mask,
            src_mask=attn_mask,
            memory_key_padding_mask=key_mask,
            memory_mask=attn_mask.all(dim=1, keepdim=True).repeat((1, self.variances.num_embeddings, 1)),
        )
        covariances = self.transformer(
            src=x,
            tgt=self.covariances.weight.unsqueeze(dim=0).repeat((x.size(0), 1, 1)),
            src_key_padding_mask=key_mask,
            src_mask=attn_mask,
            memory_key_padding_mask=key_mask,
            memory_mask=attn_mask.all(dim=1, keepdim=True).repeat((1, self.covariances.num_embeddings, 1)),
        )

        means = self.means_out(means).squeeze(dim=-1)
        variances = self.variances_out(variances).squeeze(dim=-1)
        variances = self.softplus(variances)
        covariances = self.covariances_out(covariances).squeeze(dim=-1)
        scale_tril = self._create_scale_tril(variances, covariances)

        return MultivariateNormal(loc=means, scale_tril=scale_tril)

    def compute_loss(self, x: PackedSequence, y: Tensor) -> dict:
        dist = self.forward(x)
        return {
            "NLL": -dist.log_prob(y).mean(),
            "MSE": self.mse(dist.mean, y)
        }

The Problem:

I expect the transformer’s performance to at least be comparable to that of the LSTM. However, when I train the transformer, it performs much worse than the LSTM and other, much simpler predictors. Even on much simpler forecasting or classification tasks, I have never gotten a transformer to perform comparably to an LSTM. Also, the transformer’s performance is not so outlandish as to suggest that something is fundamentally broken - it is always in the same ballpark, just significantly worse.

Potential problems I have considered:

  1. Maybe the positional encoding is off? I am using the common positional encoding from Attention is All You Need that you find in every other tutorial. However, my data set consists of approx. 1000 different multivariate time series with anywhere between 2 and 30 steps each. So I was wondering whether the encoding is too granular (as in, not providing enough structure) for such short sequences, since it seems designed for much longer ones. (See the sketch for this point after the list.)

  2. Maybe the masking is off? The time series all have varying lengths. For LSTMs there exists the wrapper class PackedSequence, which conveniently sidesteps the issue. As far as I know, nothing analogous exists for transformers (please correct me if I am wrong!). Hence, for batch learning I need to employ some form of masking logic (which, admittedly, I implemented very nastily - I would be happy to hear suggestions for improvement; see the sketch for this point after the list). I was wondering whether somewhere along the way I misunderstood how the masking is supposed to be used, or employed it in the wrong manner?

  3. Maybe my architecture is conceptually flawed? As can be seen above, I use embeddings for the means, variances and covariances of the distribution. In the forward pass I query the embedding spaces, using the input time series as the source and the embedding weights as the targets, and then project from latent space to single variables with linear layers. Is this conceptually sound? (I also tried using a single embedding plus linear layers to extract the means, variances and covariances, but to no avail.)

  4. Do I need more data? Is the problem that transformers are too “data-hungry” and would only surpass an LSTM on larger data sets?
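
Sketch for point 1: a learned positional embedding is a common drop-in replacement for the sinusoidal PositionalEncoding above when sequences are this short. The module below is illustrative only (the name LearnedPositionalEncoding and max_len=30 are placeholders, not tuned values):

import torch
from torch import nn, Tensor

class LearnedPositionalEncoding(nn.Module):
    """One learned offset vector per time step, added to the input embeddings."""

    def __init__(self, d_model: int, max_len: int = 30, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.pe = nn.Embedding(num_embeddings=max_len, embedding_dim=d_model)

    def forward(self, x: Tensor) -> Tensor:
        # x: (batch, seq, d_model); add one learned vector per position.
        positions = torch.arange(x.size(1), device=x.device)
        return self.dropout(x + self.pe(positions).unsqueeze(0))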
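
Sketch for point 2: a leaner masking variant that relies only on the key-padding masks nn.Transformer accepts, instead of building full per-head src_mask/memory_mask tensors. Same _PADDING convention as above; untested as written:

import torch
from torch import Tensor
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence

def pad_and_produce_key_mask(x: PackedSequence, padding_value: float = -999999999) -> tuple[Tensor, Tensor]:
    # Pad the packed batch; pad_packed_sequence also returns the true lengths.
    x_padded, lengths = pad_packed_sequence(x, batch_first=True, padding_value=padding_value)
    # (batch, seq): True for positions beyond each sequence's true length.
    positions = torch.arange(x_padded.size(1), device=x_padded.device)
    key_padding_mask = positions.unsqueeze(0) >= lengths.to(x_padded.device).unsqueeze(1)
    return x_padded, key_padding_mask

The transformer calls would then only pass src_key_padding_mask=key_padding_mask and memory_key_padding_mask=key_padding_mask, dropping src_mask and memory_mask, since the padded positions are already excluded as keys.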

How to go on?

I am at my wit’s end. I have tried changing the architecture, learning-rate warm-up and scheduling, many parameters, few parameters, lots of regularization, little regularization - nothing seems to do the trick. It is really hard to find resources online discussing similar problems / prediction tasks, so I would appreciate any help! Thanks for reading this post if you have made it this far! :slight_smile:

I think your problem is in your expectations when you say “I expect the transformer’s performance to at least compare to that of the LSTM.” There are many, many tasks where an LSTM beats a transformer, and even a GRU beats an LSTM. It really depends on the task and the amount of data you have.

Let’s pick two examples: Weather forecasting and language modelling.

Weather forecasting has a very simple solution - tomorrow’s weather will be the same as today’s. This is true of many analog autoregressive problems. What’s more, the influence of the past decays rapidly, something like exponentially. The RNN variants are very good at this - the exponential decay comes naturally to them.

Language modelling also shows an exponential decay if you only have tens or hundreds of millions of words to train on. It’s only when you get to internet scale - many billions of words - that the very long-range dependencies start to kick in.

We don’t know your task, but I really think that you are expecting too much out of transformers. You may even get something by going from LSTM to GRU - that’s my current research topic and it’s going very well.

Hey there! Thanks for the quick and insightful answer! I had the feeling that LSTMs have an inherent advantage for my prediction task, as my data does indeed have strong autocorrelation and only short-range dependencies. Nevertheless, I am a little dumbfounded that I seem unable to get the transformer anywhere near the performance of the LSTM, as the general consensus seems to be that transformers outperform recurrent networks on most tasks - but maybe my impression is inaccurate. Out of curiosity: would you happen to have literature at hand that explores examples where recurrent networks outperform transformers? I only know of one paper challenging the idea that transformers are effective for time series forecasting, but it was refuted a little later due to flaws in its experimental design. Anyway, I will give GRUs a try and see how they handle my prediction task. :raised_hands:
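
For anyone reading along later, the swap itself is small: nn.GRU takes the same constructor arguments as nn.LSTM but returns a single hidden state instead of (hn, cn). A sketch of the variant I intend to try, reusing the heads of ProbabilisticLSTM above (untested as written; the attribute keeps the name lstm only to minimize the changes):

class ProbabilisticGRU(ProbabilisticLSTM):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__(input_size=input_size, hidden_size=hidden_size)
        # Replace the LSTM core with a GRU of the same size.
        self.lstm = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True)

    def forward(self, x: PackedSequence | Tensor, last_hidden: Tensor | None = None):
        _, hn = self.lstm(x, last_hidden)  # a GRU has no cell state
        h = hn.squeeze(dim=0)
        means = self.mean_predictor(h)
        variances = self.softplus(self.variances(h))
        covariances = self.covariances(h)
        scale_tril = self._create_scale_tril(variances, covariances)
        return MultivariateNormal(loc=means, scale_tril=scale_tril), hn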

I know, I get the same ‘general consensus’ argument all the time. I think it’s selection bias - it’s transformers that hit the headlines, because throwing GPUs and data at a problem is effective - unless you don’t have a huge budget, in which case you don’t hit the headlines.

I’ve nearly 40 years’ experience with RNNs, probably more than anyone else still in the field. I have seen transformers fail to replace various RNN variants, but it wasn’t published.

Have you counted parameters and done ablation studies with smaller amounts of data and fewer parameters? You may be able to show a trend where, with one or two orders of magnitude more data, transformers would win.
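
Counting trainable parameters is a one-liner in PyTorch, e.g. (lstm_model and transformer_model are just placeholder names for your two models):

def count_parameters(model: nn.Module) -> int:
    # Total number of trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(lstm_model), count_parameters(transformer_model))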

You asked for a reference. This paper just popped up in my feed today: [2409.17703] PGN: The RNN's New Successor is Effective for Long-Range Time Series Forecasting. I haven’t understood it yet, and there is no simple LSTM vs. Transformer comparison, but the authors’ favourite variation on RNN beats the variation on transformers on the tasks they chose.