Train RNN on time series data keeping all time relationships

I want to train an RNN (LSTM) on time series data, doing backpropagation at every step while keeping the time relationships from start to end.

I tried many things without success. If I set seq_len to the number of data points I have, the model trains quickly and keeps temporal relationships, but it only updates weights once per epoch.

The last thing I tried is setting seq_len and batch_size to 1, and passing the hidden states on every iteration:

hidden = None
for i in range(len(X)):
    # reshape one timestep to (batch=1, seq_len=1, n_features)
    single_tick = X[i].view(1, X[i].shape[0], X[i].shape[1])
    y_pred, hidden = net(single_tick, hidden)

    optimizer.zero_grad()
    loss = criterion(y_pred, y[i])
    loss.backward(retain_graph=True)
    optimizer.step()

    train_loss_total += loss.item()

Note that I set retain_graph to True.

This kinda works, but each epoch takes so long that it’s almost unusable.

I want to know what the common practice is for this apparently simple task (keeping time relationships across long datasets). In Keras this works without having to do anything special, so I imagine there is no technical limitation, just a lack of knowledge on my part.

This is my model:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, num_layers=2, hidden_size=256):
        super(Model, self).__init__()

        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(self.input_size, hidden_size=self.hidden_size,
                            num_layers=self.num_layers, dropout=0.2, batch_first=True)
        self.dense = nn.Linear(self.hidden_size, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x, hidden=None):
        batch_size = x.shape[0]

        if hidden is None:
            # no state passed in: start from a random initial state
            h0 = torch.randn(self.num_layers, batch_size, self.hidden_size)
            c0 = torch.randn(self.num_layers, batch_size, self.hidden_size)
        else:
            (h0, c0) = hidden

        output, hidden = self.lstm(x, (h0, c0))
        # note: this view only works when seq_len == 1
        output = output.view(batch_size, self.hidden_size)
        output = self.activation(self.dense(output))
        return output, hidden

A good idea, but the weights get updated at every timestep, and since each update backpropagates all the way back to the start of the sequence, it gets pretty slow.

Might I suggest a compromise? Keep batch_size=1 and, instead of setting seq_len to total_seq_len or to 1, try some value in between. E.g. if you set seq_len to total_seq_len/4, then the weights will be updated four times per epoch, which will run a lot faster than updating once per timestep. Note that the fourth batch will run noticeably slower than the first because it has further to backpropagate.
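Roughly, that compromise could look like the sketch below. The names net, X, y, criterion and optimizer come from your code; total_seq_len, n_features and the chunk count of 4 are assumptions, and the model would need to return a prediction for every timestep of a chunk rather than only the last one:

seq_len = total_seq_len // 4  # anything between 1 and total_seq_len

hidden = None
for start in range(0, total_seq_len, seq_len):
    # assumes X has shape (total_seq_len, n_features) and y has shape (total_seq_len, 1)
    chunk_x = X[start:start + seq_len].unsqueeze(0)   # (1, seq_len, n_features)
    chunk_y = y[start:start + seq_len].unsqueeze(0)   # (1, seq_len, 1)

    y_pred, hidden = net(chunk_x, hidden)

    optimizer.zero_grad()
    loss = criterion(y_pred, chunk_y)
    # retain_graph=True keeps earlier chunks in the graph, so later chunks
    # backpropagate further back (which is why they run slower)
    loss.backward(retain_graph=True)
    optimizer.step()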

I don’t think Keras backpropagates right back to the sequence start unless you feed in the whole sequence at once. I am pretty sure it only ever backpropagates to the beginning of each batch.

I think the common approach is to set the seq_len to be relatively large, to call backward without the retain_graph option, and to hope for the best. Keeping the hidden state from one batch to the next will provide a summary of the previous history that the model can learn to use. So the model can learn some long term relationships, but it can’t learn them as finely as it might if it could backpropagate all the way through.
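As a minimal sketch of that common approach (standard truncated backpropagation through time, with the same shape assumptions as the sketch above): the one PyTorch-specific detail is to detach the carried hidden state at each chunk boundary, otherwise calling backward() without retain_graph raises an error about the graph having already been freed.

hidden = None
for start in range(0, total_seq_len, seq_len):
    chunk_x = X[start:start + seq_len].unsqueeze(0)
    chunk_y = y[start:start + seq_len].unsqueeze(0)

    # detach so backward() only reaches the start of this chunk;
    # the hidden state still carries a summary of everything before it
    if hidden is not None:
        hidden = tuple(h.detach() for h in hidden)

    y_pred, hidden = net(chunk_x, hidden)

    optimizer.zero_grad()
    loss = criterion(y_pred, chunk_y)
    loss.backward()  # no retain_graph needed now
    optimizer.step()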

Another approach (which will be a little harder to implement) is described in this paper: [1705.08209] Unbiasing Truncated Backpropagation Through Time. Among other things, it involves randomising the seq_len for each batch.

If you need to run in an online learning setting, i.e. updating the weights at every timestep in an efficient manner, then the following paper can help: [1702.05043] Unbiased Online Recurrent Optimization.