Seq2seq model with attention for time series forecasting

Hi, I’m putting together a basic seq2seq model with attention for time series forecasting. I couldn’t find any basic guide for this, so I’m following this NLP tutorial (NLP From Scratch: Translation with a Sequence to Sequence Network and Attention — PyTorch Tutorials 2.2.0+cu121 documentation) and trying to adapt it to time series forecasting.

This is the time series model architecture I have now.

I have followed the architecture in the tutorial and removed the embedding layers, which are only needed for NLP.
Here is the code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class RNNEncoder(nn.Module):
    def __init__(self, rnn_num_layers=1, input_feature_len=1, sequence_len=168, hidden_size=100, bidirectional=False):
        super().__init__()
        self.sequence_len = sequence_len
        self.hidden_size = hidden_size
        self.input_feature_len = input_feature_len
        self.num_layers = rnn_num_layers
        self.rnn_directions = 2 if bidirectional else 1
        self.gru = nn.GRU(
            num_layers=rnn_num_layers,
            input_size=input_feature_len,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=bidirectional
        )

    def forward(self, input_seq):
        # create the initial hidden state on the same device as the input
        # instead of hard-coding 'cuda'
        ht = torch.zeros(self.num_layers * self.rnn_directions, input_seq.size(0), self.hidden_size,
                         device=input_seq.device)
        if input_seq.ndim < 3:
            input_seq = input_seq.unsqueeze(2)  # (batch, seq_len) -> (batch, seq_len, 1)
        gru_out, hidden = self.gru(input_seq, ht)
        if self.rnn_directions > 1:
            # sum the forward and backward outputs so the decoder always sees hidden_size features
            gru_out = gru_out.view(input_seq.size(0), self.sequence_len, self.rnn_directions, self.hidden_size)
            gru_out = torch.sum(gru_out, dim=2)
        # collapse (num_layers * directions, batch, hidden) -> (batch, hidden) by keeping
        # the last layer and summing its directions; the previous squeeze(0) only handled
        # the single-layer unidirectional case
        hidden = hidden.view(self.num_layers, self.rnn_directions, input_seq.size(0), self.hidden_size)
        hidden = hidden[-1].sum(dim=0)
        return gru_out, hidden
    
    
class AttentionDecoderCell(nn.Module):
    def __init__(self, input_feature_len, hidden_size, sequence_len):
        super().__init__()
        # attention - inputs - (decoder_inputs, prev_hidden) -> one score per encoder timestep
        self.attention_linear = nn.Linear(hidden_size + input_feature_len, sequence_len)
        # the GRU cell consumes the attention-weighted sum of encoder outputs, hence
        # input_size=hidden_size; note that y only enters through the attention query
        self.decoder_rnn_cell = nn.GRUCell(
            input_size=hidden_size,
            hidden_size=hidden_size,
        )
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, encoder_output, prev_hidden, y):
        attention_input = torch.cat((prev_hidden, y), dim=1)
        # softmax over the sequence dimension; passing dim explicitly avoids the
        # deprecation warning and guarantees the weights sum to 1 per sample
        attention_weights = F.softmax(self.attention_linear(attention_input), dim=1).unsqueeze(1)
        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, hidden)
        attention_combine = torch.bmm(attention_weights, encoder_output).squeeze(1)
        rnn_hidden = self.decoder_rnn_cell(attention_combine, prev_hidden)
        output = self.out(rnn_hidden)
        return output, rnn_hidden
    

class EncoderDecoderWrapper(nn.Module):
    # subclassing nn.Module (the original called super().__init__() without inheriting
    # from it) registers encoder and decoder_cell as submodules, so train()/eval(),
    # state_dict() and load_state_dict() now work out of the box and the manual
    # versions of those methods are no longer needed
    def __init__(self, encoder, decoder_cell, output_size=3, teacher_forcing=0.3):
        super().__init__()
        self.encoder = encoder
        self.decoder_cell = decoder_cell
        self.output_size = output_size
        self.teacher_forcing = teacher_forcing

    def forward(self, xb, yb=None):
        input_seq = xb
        if input_seq.ndim < 3:
            input_seq = input_seq.unsqueeze(2)  # (batch, seq_len) -> (batch, seq_len, 1)
        encoder_output, encoder_hidden = self.encoder(input_seq)
        prev_hidden = encoder_hidden
        # allocate on the input's device rather than branching on cuda availability
        outputs = torch.zeros(input_seq.size(0), self.output_size, device=input_seq.device)
        y_prev = input_seq[:, -1, :]  # the last observed timestep plays the role of the START token
        for i in range(self.output_size):
            if (yb is not None) and (i > 0) and (torch.rand(1) < self.teacher_forcing):
                # teacher forcing: feed the ground truth of the *previous* step;
                # yb[:, i] would leak the very value the decoder is about to predict
                y_prev = yb[:, i - 1].unsqueeze(1)
            rnn_output, prev_hidden = self.decoder_cell(encoder_output, prev_hidden, y_prev)
            y_prev = rnn_output
            outputs[:, i] = rnn_output.squeeze(1)
        return outputs

I’m looking to see if there are any red flags in the overall architecture. Does this implementation of global attention make sense for time series forecasting? Currently, I see a slight improvement in my results compared to a vanilla RNN approach.

Specifically, in the NLP model a START_TOKEN is the first input to the decoder; in this case I’m passing the last timestep of the X sequence instead. Is this alright?

The attention linear layer is reused in a loop for the n future timesteps; would it make sense to maintain a separate layer for each timestep instead?
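For concreteness, a per-timestep variant might look something like this rough sketch (the PerStepAttention class, the horizon argument and the step indexing are hypothetical illustrations, not part of the code above):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch: one attention projection per forecast step, selected by
# 'step' inside the decoding loop, instead of reusing a single attention_linear.
class PerStepAttention(nn.Module):
    def __init__(self, input_feature_len, hidden_size, sequence_len, horizon):
        super().__init__()
        self.attention_linears = nn.ModuleList(
            nn.Linear(hidden_size + input_feature_len, sequence_len) for _ in range(horizon)
        )

    def forward(self, prev_hidden, y, step):
        scores = self.attention_linears[step](torch.cat((prev_hidden, y), dim=1))
        return F.softmax(scores, dim=1)  # (batch, sequence_len) weights for this step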

And let me know if I can look at any resources to make further improvements to the model.
Thanks :smiley:


Interesting. But I don’t see a justification for using stale past states in forecasting, as in NLP tasks. Only capturing periodicities comes to mind, but this model is overkill for that problem.

Yes, but you don’t truly “compress” past data, so your model may learn simpler transition rules (e.g. an AR(n)-like model).

Hi Alex, these were my reasons for exploring this approach:

  1. The model is used to forecast multiple time-series (around 10K time-series), sort of like predicting the sales of each product in each store. I don’t want the overhead of training multiple models, so deep learning looked like a good choice. This also gives me the freedom to add categorical data as embeddings.

  2. Initially, I tried to compress the past data by using only the last hidden state and feeding it into a Linear layer to forecast N timesteps, but this gave me poor results. Next, I tried feeding the outputs of all the GRU cells into a Linear layer to forecast N timesteps, which improved my results (sketches of both variants are below, after this list).

  3. I saw two drawbacks in the previous architecture: the forecasts of the N timesteps are independent of each other, which is not ideal, and this made me choose an encoder-decoder architecture. But then each decoder cell no longer receives the outputs of the individual encoder cells, and to address this some sort of attention mechanism was needed. And I agree that the attention mechanism ended up capturing the periodicity.
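As a reference for point 2, here is a minimal sketch of the two baselines described there (class names, sizes and the single-feature input are my own illustration):

import torch.nn as nn

# Baseline 1: forecast N steps from the last hidden state only
class LastHiddenForecaster(nn.Module):
    def __init__(self, hidden_size, n_steps):
        super().__init__()
        self.gru = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, n_steps)

    def forward(self, x):                 # x: (batch, seq_len, 1)
        _, hidden = self.gru(x)
        return self.head(hidden[-1])      # (batch, n_steps)

# Baseline 2: forecast N steps from the outputs of all GRU cells
class AllOutputsForecaster(nn.Module):
    def __init__(self, hidden_size, seq_len, n_steps):
        super().__init__()
        self.gru = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size * seq_len, n_steps)

    def forward(self, x):
        out, _ = self.gru(x)              # (batch, seq_len, hidden)
        return self.head(out.flatten(1))  # (batch, n_steps)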

This is the plot of the attention weights the model learned. The model is paying attention to timesteps from the distant past too, which is in line with what I thought would happen.

One simplification I want to explore is removing the attention layer and just feeding lagged timesteps to the decoder directly.

Let me know if I’m headed in the right direction.

You can transform the last hidden state and make it the initial state for another forecasting RNN. You may feed it some future time measure, the last output (if you do one step at a time), or even nothing at all. That’s a decoder too.
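A minimal sketch of that idea, assuming a one-step-at-a-time decoder that feeds its own previous output back in (PlainDecoder and all names here are illustrative):

import torch
import torch.nn as nn

class PlainDecoder(nn.Module):
    def __init__(self, hidden_size, n_steps):
        super().__init__()
        self.transform = nn.Linear(hidden_size, hidden_size)  # transform the encoder state
        self.cell = nn.GRUCell(input_size=1, hidden_size=hidden_size)
        self.out = nn.Linear(hidden_size, 1)
        self.n_steps = n_steps

    def forward(self, encoder_hidden):                        # (batch, hidden)
        h = torch.tanh(self.transform(encoder_hidden))        # initial decoder state
        x = torch.zeros(encoder_hidden.size(0), 1, device=encoder_hidden.device)
        preds = []
        for _ in range(self.n_steps):
            h = self.cell(x, h)
            x = self.out(h)                                   # feed the last output back in
            preds.append(x)
        return torch.cat(preds, dim=1)                        # (batch, n_steps)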

Yes, something like that.

About attention… A contrived example - if the encoder captures speed & acceleration, there is no need to look back. If you want to replicate a stationary series, maybe with some variations - yeah, it may be easier with attention. But then it eats a lot of extra resources too…

I tried an encoder-decoder model without attention, as you described, passing my previous forecast as the input to the next decoder cell, but this didn’t give me an improvement over my previous model of feeding the outputs of all the encoder GRU cells into a Linear layer.

And I’m seeing an issue with passing the lags directly to the decoder too. The attention model is currently able to look at a window of values around the periodically important lag. For example, if I choose to pass the t - 24 lag to my decoder, the attention model does this better by giving high scores to t - 25, t - 24 and t - 26, with the peak at t - 24.
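For reference, the “pass the lag directly” alternative discussed here boils down to something like this sketch (the helper name is hypothetical, and it assumes the horizon is not longer than the lag):

import torch

def seasonal_lag_inputs(history, horizon, lag=24):
    # history: (batch, seq_len, 1) observed values; returns (batch, horizon, 1),
    # where step i of the output is the value exactly `lag` steps before forecast step i
    return torch.stack([history[:, -lag + i, :] for i in range(horizon)], dim=1)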

I understand your point that attention is resource-intensive, but given that I want to squeeze the extra performance out of the model, is this implementation of attention alright? Is there any other example of using attention for time-series forecasting?

It is difficult to compare these…

I haven’t noticed any coding/logic issues.

I’m sure arxiv.org has something on that :slight_smile:


Actually, your architecture belongs to what Goodfellow’s book describes in the “Explicit Memory” section. You made two particular design choices:

  1. The attention weights, attention_combine and the hidden state depend on each other. This is expressive and appropriate for NLP, but it makes your “cell” unparallelizable. I could imagine a simple GRU producing attention weights without consulting the “explicit memory”, and a single 4d attention query after that…

  2. memory == the encoder’s stepwise output. Interestingly, the memory can be anything…

Thanks Alex. I can understand your point on explicit memory. I’ll read up on this part.

I don’t completely understand this. My understanding of the attention mechanism until now is that it looks at the outputs of past timesteps to determine the importance of each encoder output for the current decoder output. Is there an alternative? If it’s not looking at explicit memory, what would the attention weights be applied to?

And I’m exploring stuff on arxiv on this topic too; there are advancements in the attention mechanism itself, like adding convolutions to it - https://arxiv.org/pdf/1809.04206v3.pdf. And thanks for validating the coding part too :+1:

What I meant is that your decoder receives some initial hidden state that can by itself express pretty complex dynamics. But you’re using a more complex way to produce the attention_weights, interleaving their generation with the hidden state transitions. My theory is that, as attention_weights are just query keys, you can generate them in isolation, i.e. without reading the “explicit memory” - the advantage is that this is doable with faster standard RNN cells. Well, that may be too different from what you’re doing…
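A rough sketch of one way to read this idea - a plain GRU rolls the decoder state forward without touching the memory, all per-step query keys come out of one parallel pass, and the explicit memory is then read with a single batched matmul (the class and all sizes are my illustration, not necessarily the exact proposal):

import torch
import torch.nn as nn
import torch.nn.functional as F

class DetachedQueryAttention(nn.Module):
    def __init__(self, hidden_size, sequence_len, horizon):
        super().__init__()
        self.query_rnn = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.to_scores = nn.Linear(hidden_size, sequence_len)
        self.out = nn.Linear(hidden_size, 1)
        self.horizon = horizon

    def forward(self, encoder_output, encoder_hidden):
        # encoder_output: (batch, seq_len, hidden), encoder_hidden: (batch, hidden)
        batch = encoder_output.size(0)
        dummy = torch.zeros(batch, self.horizon, 1, device=encoder_output.device)
        q, _ = self.query_rnn(dummy, encoder_hidden.unsqueeze(0))  # all queries in one pass
        weights = F.softmax(self.to_scores(q), dim=2)              # (batch, horizon, seq_len)
        context = torch.bmm(weights, encoder_output)               # single batched memory read
        return self.out(context).squeeze(2)                       # (batch, horizon)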

@Gautham_Kumaran, I am also working on a model which should forecast multiple time series (a climatological study considering ~100k observation sites around the globe). Ideally I would also have the option to add categorical data as embeddings. However, I’m not there yet and am currently still struggling with feeding the time series appropriately into the network.

Do you have a basic seq2seq example for multiple time series? I thought this would be a pretty straightforward task, but I’m struggling and cannot find any straightforward examples.

Many thanks in advance.

@sirolf I thought this would be straightforward too, but I wasn’t able to find any example implementation, so I wrote this article based on my experience working on this problem - Encoder-Decoder Model for Multistep Time Series Forecasting Using PyTorch - hope this helps.

Regarding adding categorical variables, there are multiple ways to do this; check this discussion - https://datascience.stackexchange.com/questions/17099/adding-features-to-time-series-model-lstm/17139#17139. I’ve tried all the methods mentioned there; the implementations can be found here - https://github.com/gautham20/pytorch-ts/blob/master/ts_models/encoders.py. In my final model, I one-hot encoded the categorical variables and fed them to the RNN directly.
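For reference, the “one-hot encode and feed to the RNN directly” option might look something like this sketch (the helper name and shapes are illustrative, assuming one integer category id per series):

import torch
import torch.nn.functional as F

def add_onehot_category(x, cat_ids, n_categories):
    # x: (batch, seq_len, n_features), cat_ids: (batch,) integer labels
    onehot = F.one_hot(cat_ids, n_categories).float()       # (batch, n_categories)
    onehot = onehot.unsqueeze(1).expand(-1, x.size(1), -1)  # repeat along the time axis
    return torch.cat([x, onehot], dim=2)                    # widen the feature dimension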


Cheers, at first glance it all appears very useful. I’ll have a detailed look now. Thanks for sharing!

Hi @Gautham_Kumaran, just saw your post at TowardsDataScience. Looks great!

However, I have a concern regarding your proposed architecture. In the diagram you shared, assuming that t-1 is the last known step at inference time:

You are feeding the last observed target t-1 twice: once into the encoder and another time into the decoder.

This means that, essentially, your hidden state (or context vector, the yellow box called hidden in your diagram) already has info regarding t-1. However, you are re-feeding t-1 yet again to the decoder - which makes sense, given that you need to give the decoder something as input to output the prediction for t. But wouldn’t it be better to feed the encoder only up to t-2?

This way you avoid giving essentially the same info to your decoder twice in the first decoder step; hence, at all times the hidden state in both the encoder and the decoder is coherent content-wise and contains the internal representation of all seen (or inferred) timesteps only once.
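In code, this boils down to slicing the observed window differently (a sketch, assuming a batch-first tensor with t-1 in the last position):

def split_for_decoder(x):
    # x: (batch, seq_len, features) window of observed values, with t-1 last
    encoder_input = x[:, :-1, :]   # the encoder sees up to and including t-2
    decoder_start = x[:, -1, :]    # t-1 appears only once, as the first decoder input
    return encoder_input, decoder_start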

This is the way it’s done in the DeepAR paper, for instance. The only difference is that in DeepAR the encoder and the decoder are actually the same module, sharing the exact same architecture and weights.

With that said, I have also seen publications where the input to the decoder is always a copy of the encoder’s output vector, kept constant for all decoder timesteps… So perhaps, as long as your loss and training procedure are OK, the network learns to accommodate whatever input you feed it…

What do you think?


Thanks @julioasoto

Passing up to t - 2 to the encoder makes total sense. Passing t - 1 twice was bugging me too, but I carried on for lack of an alternative. If I get the chance to retrain this model, I’ll use this approach and update with what I find.

I tried passing the encoder output to each decoder timestep in the attention model. This was my takeaway from the experiment: if the data has good seasonality or any other strong DateTime pattern, the attention mechanism gives a negligible improvement over the basic seq2seq architecture (this was the case in the store item dataset); on a messy time-series dataset, adding the attention mechanism did provide a good improvement.


Great comment! Can you give references to some of the publications you are referring to? I’ve tried different types of inputs to the decoder which have, in general, given good results. However, I’d like to see different approaches that are well justified.

Yes, in my experience results look very similar with almost any sensible approach.

An example of this t-2 approach I mentioned above is the DeepAR paper: https://arxiv.org/abs/1704.04110

Other people actually feed t-1 twice: once as the last timestep for the encoder and again as the first timestep for the decoder. This is the approach used by the winner of this Kaggle competition back in 2017: https://www.kaggle.com/c/web-traffic-time-series-forecasting/discussion/43795

As for feeding the encoder’s last hidden state as input for all decoder timesteps, the most popular example is in Guillaume Chevalier’s great tutorial on signal prediction with seq2seq models: https://github.com/guillaume-chevalier/seq2seq-signal-prediction
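For reference, that third variant amounts to something like this sketch (the helper name is mine):

def constant_context_inputs(encoder_hidden, horizon):
    # encoder_hidden: (batch, hidden) -> (batch, horizon, hidden),
    # the same context vector repeated as the input for every decoder timestep
    return encoder_hidden.unsqueeze(1).expand(-1, horizon, -1)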

After trying these different approaches, my conclusion is that they pretty much all work and yield very similar results on a wide range of experiments and datasets (the Airline Passengers dataset, the M4 competition…). It looks like as long as you have an appropriate number of weights and your loss makes sense, the network figures out the best way to deal with your architectural choice, as long as it is sensible, like the three described in these paragraphs.

Hi @Gautham_Kumaran,

Thanks a lot for your work. I am using this as a reference for one of my projects. However, I have a query regarding the data we pass to the encoder.

So imagine a case in which we have 500 series to predict. My input sequence length is 180 and my output sequence length is 150. So, the input data shape will be something like (batch_size, 180, 500), right? And the output shape will be (batch_size, 150, 500). (Please correct me if I’m wrong here.)

Now imagine a case where we have some external factors (F1, F2, F3) that affect the series and we have 1 year of data. I was thinking of reshaping the data to (365, 7, 500) where:

365: no of batches (each batch represents a day of each time series)
7: (F1, F2, F3, day, month, year, product_id)
500: no of time series

Output shape, if you’re going to predict for 1 day: (1, 1, 500),

where the second dimension is the sales with respect to those external factors (F1, F2, F3).

In this case, if we use an attention model, all the external factors will be assigned proper weights in order to predict the sales data, right? Do you think this would be a good approach? Please help me out.

Hello @Gautham_Kumaran, I tried to execute RNNConcatEncoder from your project but it shows the following error: forward() missing 1 required positional argument: ‘input_cat’. Is input_cat supposed to be sequence_data[cat_columns]?