LSTM time-series prediction

I’m using an LSTM to predict a time series of floats. I’m using a window of 20 prior datapoints (seq_length = 20) and no additional features (input_dim = 1) to predict the “next” single datapoint. My network seems to be learning properly.

Here’s the observed data vs. predicted with the trained model:

Here’s a naive implementation of how to predict multiple steps ahead using the trained network:

import numpy as np
import torch

data = timeseries[-seq_length:]  # Last observed data (20 datapoints)
last_seq = data.reshape(seq_length, 1, input_dim)  # (seq_len, batch=1, input_dim) = (20, 1, 1)
last_seq = torch.from_numpy(last_seq).float()  # PyTorch tensor of floats
with torch.no_grad():  # Inference only, so no gradients are needed
    last_pred = model(last_seq)  # Call the module instead of model.forward
last_pred = last_pred.numpy()
data = np.append(data, last_pred)  # Now contains 21 values; take the last 20 and repeat
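
The "take the last 20 and repeat" step can be written as a rolling-window loop. This is only a minimal sketch: `forecast`, `predict_next` and `mean_model` are hypothetical names, and `mean_model` is a toy stand-in for the trained network (anything that maps a window of floats to one float would do).

```python
from collections import deque


def forecast(history, predict_next, seq_length=20, steps=10):
    """Roll a fixed-size window forward, feeding each prediction back in."""
    window = deque(history[-seq_length:], maxlen=seq_length)
    predictions = []
    for _ in range(steps):
        next_value = predict_next(list(window))  # one-step-ahead prediction
        predictions.append(next_value)
        window.append(next_value)  # the oldest value drops out automatically
    return predictions


# Toy stand-in for the trained network (assumption): the mean of the window.
def mean_model(window):
    return sum(window) / len(window)


preds = forecast([float(i % 5) for i in range(40)], mean_model, steps=3)
```

After `seq_length` steps the window contains only predicted values, which is exactly the regime the network never saw during training.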

When I do this recurrently, always predicting from the last 20 values (which at some point include only predicted values), I get:

  1. What is the right way to use a trained LSTM to generate predictions for future timesteps?
  2. Are LSTMs the right tool to predict sequences of floats? Will something like seq2seq allow me to predict sequences of floats from sequences of floats?

Somewhat orthogonal to your question: it looks like your time series is cyclical. Have you considered adding a couple of input features that represent the cyclical time? For example:

sin(time_of_day / 24 * 2 * pi)
cos(time_of_day / 24 * 2 * pi)
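
As a small illustration of the two features above, here is a sketch in plain Python. The function name `cyclical_time_features` and the `period` parameter are assumptions made for the example; `time_of_day` is taken to be in hours.

```python
import math


def cyclical_time_features(time_of_day, period=24.0):
    """Map a time of day (in hours) onto the unit circle."""
    angle = 2.0 * math.pi * time_of_day / period
    return math.sin(angle), math.cos(angle)


# Hours just before and just after midnight end up close together,
# which a raw 0..24 feature would not give you.
features = [cyclical_time_features(h) for h in (0.0, 6.0, 12.0, 18.0, 23.5)]
```

Both values are needed: the sine alone cannot tell 03:00 from 09:00, and the cosine alone cannot tell 06:00 from 18:00.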

  1. Your network is trained to predict the next timestep given the previous, perfect timesteps (which form one kind of input distribution). Once trained in such a fashion, you should also train it to make predictions given its own predictions (which form another kind of input distribution which is probably a bit different from the real input distribution, hence the divergence). Not my area of expertise but you should be able to find code for this somewhere.
  2. LSTMs are fine - your problem is what I talked about above.
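
One possible way to train the network on its own predictions is sketched below. This is an assumption-laden example, not a reference implementation: the class `OneStepModel`, the `teacher_forcing_ratio` argument, the linear decay schedule and the sine-wave data are all made up for illustration.

```python
import random

import torch
import torch.nn as nn


class OneStepModel(nn.Module):
    """Steps an LSTMCell manually so the input at each step can be chosen."""

    def __init__(self, hidden_size=16):
        super().__init__()
        self.cell = nn.LSTMCell(1, hidden_size)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, series, teacher_forcing_ratio=1.0):
        # series: (T, batch, 1); returns predictions for steps 1..T-1.
        batch = series.size(1)
        h = series.new_zeros(batch, self.cell.hidden_size)
        c = series.new_zeros(batch, self.cell.hidden_size)
        inp = series[0]
        outputs = []
        for t in range(1, series.size(0)):
            h, c = self.cell(inp, (h, c))
            pred = self.head(h)
            outputs.append(pred)
            # Feed the true value with probability teacher_forcing_ratio,
            # otherwise feed back the model's own (detached) prediction.
            use_truth = random.random() < teacher_forcing_ratio
            inp = series[t] if use_truth else pred.detach()
        return torch.stack(outputs)


random.seed(0)
torch.manual_seed(0)
model = OneStepModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
series = torch.sin(torch.linspace(0, 12.0, 60)).view(-1, 1, 1)

n_epochs = 20
for epoch in range(n_epochs):
    ratio = max(0.0, 1.0 - epoch / n_epochs)  # rely more on own predictions over time
    optimizer.zero_grad()
    preds = model(series, teacher_forcing_ratio=ratio)
    loss = nn.functional.mse_loss(preds, series[1:])
    loss.backward()
    optimizer.step()
```

Detaching the fed-back prediction is a design choice made here for simplicity; whether to backpropagate through it is a separate decision.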

Yes, definitely a strong cyclical component. I was thinking of using one-hot vectors for time, but your approach looks better. Thanks for the suggestion.

Very interesting. So retrain the same model, or use it as input for a second model? Would you mind providing a rough pointer of where to look for the code you suggest? I searched and couldn’t find anything specific along these lines.

The PyTorch official tutorials have one which mentions this training (teacher forcing) and how to train against its own distribution: http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html. Note that even though this is seq2seq it is relevant for you as well.


I’ve seen teacher forcing thrown around quite a bit, thanks for pointing me in this direction.


Regardless of the teacher forcing (re)training, does the implementation of prediction of future values I posted in the OP look fine to you?

Looks like a sensible way of feeding in the data to me, though it would be best to follow some PyTorch tutorials on using RNNs to see what the convention is.


By the looks of it, you are using nn.LSTM? You would want to use nn.LSTMCell for this.


Just out of curiosity, what is the difference, both in general and in use case, between LSTM and LSTMCell?

Hi

I think LSTMCell = do your own timestep, LSTM = pass the entire sequence. Use case: use LSTM unless you have a reason not to; reasons are things that LSTM/cuDNN don’t support, e.g. attention or teacher forcing. OpenNMT-py uses both, in its encoder and decoder.
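
To make the "own timestep vs. entire sequence" distinction concrete, here is a sketch that runs the same weights through both modules. The sizes are arbitrary assumptions; the parameter names (`weight_ih_l0`, `weight_ih`, and so on) are those of a single-layer, unidirectional `nn.LSTM` and of `nn.LSTMCell`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_dim, hidden = 20, 1, 1, 8

lstm = nn.LSTM(input_dim, hidden)      # consumes a whole sequence per call
cell = nn.LSTMCell(input_dim, hidden)  # consumes a single timestep per call

# Copy the weights so that both modules compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(seq_len, batch, input_dim)

with torch.no_grad():
    # nn.LSTM: one call, the loop over time happens inside the module.
    out_lstm, (h_n, c_n) = lstm(x)

    # nn.LSTMCell: the loop over time is written by hand.
    h = torch.zeros(batch, hidden)
    c = torch.zeros(batch, hidden)
    steps = []
    for t in range(seq_len):
        h, c = cell(x[t], (h, c))
        steps.append(h)
    out_cell = torch.stack(steps)

same = torch.allclose(out_lstm, out_cell, atol=1e-5)
```

The hand-written loop is what makes it possible to change the input at step t based on the output at step t-1.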

Best regards

Thomas


@dgriff @tom Looks like I’m not even using the right tool then :) Thanks for the pointer.

LSTMCell is an unrolled LSTM. What you want here is a stateful LSTM, which is much easier to build using an LSTMCell.


Hi,
So, I had lots of trouble with this issue a few months ago for my PhD.
To be exact, what I have is lots of x, y pen-trajectory data (continuous), recorded while drawing letters. I have the whole alphabet recorded for 400 writers. What I wanted was to let the LSTM (a GRU in my case) generate letters. Straightforward? Not at all!

  1. I trained the model in prediction mode (predicting the next step), and then used it to generate the letters. I was inspired by the approach Andrej Karpathy describes in his blog:
    http://karpathy.github.io/2015/05/21/rnn-effectiveness/
    The prediction result was perfect! However, when I generate (note again, this is continuous data, whereas Karpathy uses categorical data, i.e. characters), the output always flattens out (literally saturating at some value very quickly). Whatever the letter, whatever the writer style, the result is always shit!
  2. I then tried something very simple: just learn (or, technically, remember) a continuous sine wave. That took 2 weeks to resolve (just to make the GRU memorize it). Refer to the discussion in this thread
    LSTM time sequence generation
    and this question on Stack Overflow
    https://stackoverflow.com/questions/43459013/lstm-time-sequence-generation-using-pytorch
    The whole trick was to increase the sequence length enough during training to enable the algorithm to capture the whole pattern. Remember, this is just to memorize, not to learn.
    My conclusion was: discretize, discretize, discretize.
  3. I found this super awesome paper from Alex Graves
    https://arxiv.org/abs/1308.0850
    where he solves this problem (but on a different dataset). He uses something called a Mixture Density Network. It is really a beautiful thing! He managed to handle continuous data neatly.
    I tried to replicate his architecture in PyTorch (well, to be fair, I didn’t try hard enough), but it is very unstable during training. I tried to stabilize it in many ways, but I had to stop pushing in this direction shortly after (you know, supervisors are moody!).
  4. Following a piece of advice in the thread mentioned in point 2 (to make the model see its own output), I tried it (still believing that continuous data would work…). This is similar to what @Kaixhin suggested. Although the idea makes sense, how to actually do it is a big, big issue!
    In the continuous domain, nothing happens (literally nothing). The result is still shit.
    To get some idea of how to do this, take a look at this paper, called Scheduled Sampling (but the paper implements it for discrete data, of course!)
    https://arxiv.org/abs/1506.03099
    Then this awesome guy came along, proving that this scheme of feeding the model its own output can lead to problems! (And if the model at any point in time sees only its own output, it will never learn the correct distribution!!)
    https://arxiv.org/abs/1511.05101
    So, the guys who wrote the Scheduled Sampling paper wrote a new paper, acknowledging this awesome guy’s neat paper, in order to remedy the problem. It is called ‘professor forcing’
    https://arxiv.org/abs/1610.09038
    To be honest, I don’t like this paper very much, but I think it is a good step in the right direction (note again, they only use categorical data, i.e. a finite set of letters or words).
    In short, how to make the model take the error of its own distribution into account is still an open issue (to the best of my knowledge).
  5. Having reached a huge amount of failure and frustration at this point, I decided to discretize the data. Instead of having x, y, I used another encoding (Freeman codes, no need to go into details).
    I used the architecture mentioned by Karpathy in his blog + the ‘Show and Tell’ technique to bias the model (since I have the same letters, i.e. labels, from different writers)
    https://arxiv.org/abs/1411.4555
    It worked neatly!! I am super happy with it (even though you lose important info when using Freeman codes, we have fkg letters, and many are really complex and beautiful).
  6. I was feeling confident at this point, so I decided to add ‘distance’ to be predicted along with the Freeman code, and I discretized the shit out of it (forget about the details; just imagine I want to predict two different random variables instead of just one). I don’t fuse the modalities. I have two softmaxes at the end, each predicting a different random variable.
    That is super tricky! In prediction, it works perfectly. In generation, it is complete shit!
    After some thinking, I came to realize the problem. Assume the LSTM output is h, the Freeman code variable is R1 and the distance variable is R2. When I train, I train to model P(R1 | h) and P(R2 | h), but not P(R1, R2 | h). When you sample the way Karpathy does, you use a multinomial distribution for each softmax. This leads the model to follow two different paths for R1 and R2, which leads to this shit (in short, with this scheme, you need P(R1, R2 | h)). I am still working on solving this problem at the moment.
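
The problem described in point 6 can be seen on a toy distribution. The numbers below are invented purely for illustration: a hypothetical joint P(R1, R2 | h) in which the two variables always take the same index, compared with sampling each marginal separately, as two independent softmax heads would.

```python
import numpy as np

# Hypothetical joint P(R1, R2 | h): rows index R1, columns index R2.
# The two variables are perfectly coupled (only the diagonal can occur).
joint = np.array([[0.5, 0.0],
                  [0.0, 0.5]])

p_r1 = joint.sum(axis=1)  # marginal P(R1 | h)
p_r2 = joint.sum(axis=0)  # marginal P(R2 | h)

# Sampling each softmax on its own is the same as sampling from this product.
independent = np.outer(p_r1, p_r2)

# Probability mass placed on pairs that the joint says can never occur.
impossible_mass = independent[joint == 0].sum()

# Empirical check: draw R1 and R2 separately and count inconsistent pairs.
rng = np.random.default_rng(0)
n = 10_000
r1 = rng.choice(2, size=n, p=p_r1)
r2 = rng.choice(2, size=n, p=p_r2)
mismatch_rate = float(np.mean(r1 != r2))
```

Both marginals are modelled perfectly here, yet about half of the independently sampled pairs are combinations that never appear in the data.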

In short, my advice is (if you need something quick):

  1. Discretize (categorize) your data.
  2. Use Karpathy’s approach (described in his blog).
  3. If you need to bias the model, there are many techniques to do that. I tried Show and Tell. It is simple, beautiful and works well.
  4. If you want to predict more than one variable, that is still an open issue (let me know if you have some ideas).
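
The sampling step of the categorical approach in points 1 and 2 could look roughly like this. It is a sketch under assumptions: `sample_bin` is a hypothetical helper, the logits are made-up numbers standing in for the network output over discrete bins, and the temperature knob follows the idea described in Karpathy's blog post.

```python
import numpy as np


def sample_bin(logits, temperature=1.0, rng=None):
    """Sample one class index from softmax(logits / temperature)."""
    if rng is None:
        rng = np.random.default_rng()
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))


rng = np.random.default_rng(0)
logits = [0.1, 2.0, 0.3, -1.0]  # made-up network output over 4 bins
samples = [sample_bin(logits, temperature=0.8, rng=rng) for _ in range(1000)]
```

Lower temperatures concentrate the samples on the most likely bin; higher temperatures make the generated sequence more varied but also more error-prone.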

If you have time, I suggest you look at Alex Graves’s approach (I would love to see an implementation of this MDN in PyTorch).
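
For orientation, the core of a mixture density output layer can be sketched in a few lines. This is a simplified, one-dimensional illustration and not the architecture from the Graves paper: the class `MDN`, the function `mdn_nll`, the layer sizes and the random data are all assumptions.

```python
import torch
import torch.nn as nn


class MDN(nn.Module):
    """Predicts the parameters of a K-component 1-D Gaussian mixture."""

    def __init__(self, in_dim, hidden=32, n_components=5):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_dim, hidden), nn.Tanh())
        self.pi_logits = nn.Linear(hidden, n_components)  # mixture weights (pre-softmax)
        self.mu = nn.Linear(hidden, n_components)         # component means
        self.log_sigma = nn.Linear(hidden, n_components)  # log of component std devs

    def forward(self, x):
        z = self.body(x)
        return self.pi_logits(z), self.mu(z), self.log_sigma(z)


def mdn_nll(pi_logits, mu, log_sigma, y):
    """Negative log-likelihood of y (batch, 1) under the predicted mixture."""
    log_pi = torch.log_softmax(pi_logits, dim=-1)
    normal = torch.distributions.Normal(mu, log_sigma.exp())
    log_prob = normal.log_prob(y)  # y broadcasts against the K components
    # Combine in log space to avoid underflow in the mixture sum.
    return -torch.logsumexp(log_pi + log_prob, dim=-1).mean()


torch.manual_seed(0)
model = MDN(in_dim=20)
x = torch.randn(8, 20)  # e.g. 8 windows of 20 past values (random here)
y = torch.randn(8, 1)   # the next value for each window (random here)
loss = mdn_nll(*model(x), y)
loss.backward()
```

Predicting `log_sigma` rather than `sigma` keeps the standard deviation positive without an explicit constraint, which is one common way to reduce training instability.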

Good luck!


One last thing (I think you can already guess by now): a good prediction result does not at all mean a good generation result. These are different objectives. Generating sequences is done by somehow ‘tricking’ a system that was trained on prediction.
In short: Good prediction != Good generation
Good prediction + following state-of-the-art practices (mostly, discretize) ~= Good generation (where ~= means ‘hopefully equal’)


Great post, thanks for all the references and ideas! :)

I’ll have to look carefully at them, but just as a first thought: it looks like discretizing a time series would represent the data at each timestep as a one-hot vector (indicating the discrete bin where the real value falls). So instead of having a time series I would have a grid of T x R, where T is the number of timesteps and R is the number of discrete bins that partition the range of my time series. Thoughts?
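
The T x R grid described above could be built as follows. This is a sketch assuming equal-width bins between the minimum and maximum of the series; `discretize_one_hot` is a hypothetical helper name and the sine wave is only placeholder data.

```python
import numpy as np


def discretize_one_hot(series, num_bins):
    """Turn a 1-D float series into a (T, R) one-hot grid with R = num_bins."""
    series = np.asarray(series, dtype=float)
    edges = np.linspace(series.min(), series.max(), num_bins + 1)
    # Use interior edges only, so the indices fall in 0..num_bins-1.
    bin_idx = np.digitize(series, edges[1:-1])
    one_hot = np.zeros((len(series), num_bins), dtype=np.float32)
    one_hot[np.arange(len(series)), bin_idx] = 1.0
    centers = (edges[:-1] + edges[1:]) / 2.0  # used to map bins back to floats
    return one_hot, bin_idx, centers


series = np.sin(np.linspace(0, 4 * np.pi, 100))
grid, idx, centers = discretize_one_hot(series, num_bins=16)
reconstructed = centers[idx]  # approximate float series recovered from bins
```

The bin centers give a way back from predicted classes to float values, at the cost of a quantization error of at most half a bin width.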

That looks good to me.
I recommend you take a quick look at Karpathy’s blog, at the way he presents the data to the network and how he samples from the network. It should clarify any ambiguities you have about this issue (it is a fun and interesting read).


hardmaru reimplemented the MDN stuff from one of his earlier blogposts in PyTorch: https://github.com/hardmaru/pytorch_notebooks/blob/master/mixture_density_networks.ipynb


The way you have set it up, you are only using the LSTM to model the change from one data point to the next within each 20-point sequence. It doesn’t get any information about the sequences before it. Hence the cell state needs to be carried over from one sequence to the next in order to take full advantage of the LSTM, so that it encompasses not just what happens from step 1 to step 20 of a 20-datapoint sequence, but all the datapoints from 1 to 30,000, or however many original datapoints you have. Hence the data cannot be parallelized and cannot be shuffled, as the output for each sequence is needed before training on the next input and output.

If you are training on just 20 datapoints to predict one future point, using the LSTM only for the order of the sequence from 1 to 20 and not carrying the cell state forward, you might as well use a regular MLP, because you are getting no advantage from the LSTM.
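
Carrying the state forward between consecutive windows might look like the sketch below. The sizes, the sine-wave data and the per-step next-value target are assumptions for illustration; the windows are visited in order, without shuffling, and the state is detached after each update so gradients do not flow back through the whole history.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_length, input_dim, hidden = 20, 1, 16
lstm = nn.LSTM(input_dim, hidden)
head = nn.Linear(hidden, 1)
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)

# Placeholder data: one long series, shaped (T, batch=1, input_dim).
series = torch.sin(torch.linspace(0, 30.0, 201)).view(-1, 1, 1)

state = None  # (h, c); None makes nn.LSTM start from zeros
for start in range(0, series.size(0) - seq_length, seq_length):
    x = series[start:start + seq_length]          # current window
    y = series[start + 1:start + seq_length + 1]  # same window shifted by one
    out, state = lstm(x, state)                   # state from the previous window
    loss = nn.functional.mse_loss(head(out), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Keep the values but cut the graph (truncated backpropagation through time).
    state = tuple(s.detach() for s in state)
```

At prediction time the same idea applies: run the model over the observed history once, keep `state`, and continue stepping from there instead of restarting from zeros for every window.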

I hope that helps :)
