Stateful RNN example

Hi there,

I’m trying to implement a time-series prediction RNN, and for this I’m trying to build a stateful model.
Because I have one huge sequence, I want to reuse the states from previous batches instead of having them reset every time. The Keras RNN class has a stateful parameter that enables exactly this behavior:

stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.
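
A rough sketch of what that definition implies for the data layout (an illustration only, not Keras code; the helper name stateful_batches and the sizes are made up): the long series is split into batch_size parallel streams, so that row i of each batch continues row i of the previous batch.

```python
import numpy as np

def stateful_batches(series, batch_size, seq_len):
    """Yield arrays of shape (batch_size, seq_len) such that row i of one
    batch is the direct continuation of row i of the previous batch."""
    series = np.asarray(series)
    # Split the series into batch_size contiguous streams, dropping any remainder.
    stream_len = len(series) // batch_size
    streams = series[:batch_size * stream_len].reshape(batch_size, stream_len)
    # Walk along all streams in parallel, seq_len steps at a time.
    for start in range(0, stream_len - seq_len + 1, seq_len):
        yield streams[:, start:start + seq_len]

batches = list(stateful_batches(np.arange(24), batch_size=2, seq_len=3))
print(batches[0])  # rows [0, 1, 2] and [12, 13, 14]
print(batches[1])  # rows [3, 4, 5] and [15, 16, 17]
```

With this layout, the state left behind by row i is exactly the state that the next batch's row i should start from.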

I couldn’t find anything similar for PyTorch, and my attempts to do this manually have failed so far.
What I’m trying to do is save the state tensors after the forward pass and use them as the initial h and c states for LSTMCell on subsequent forward calls.

I’m very new to PyTorch, so I’m probably doing something very wrong, but so far I’m stuck.
I’d really appreciate any hint or example of a stateful RNN in PyTorch.


This thread may be relevant: [Solved] Training a simple RNN

A key point is that, to keep the hidden state Variable across batches without having to specify retain_graph=True, you need to detach it: hidden.detach().
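
A minimal sketch of that pattern, assuming a recent PyTorch where plain tensors replace Variable. The layer sizes, the random toy chunks, and the zero targets are placeholders chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lstm = nn.LSTM(input_size=1, hidden_size=8)  # expects (seq_len, batch, features)
head = nn.Linear(8, 1)
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)
criterion = nn.MSELoss()

# Toy data: 5 consecutive chunks of one long sequence, 10 steps each, batch of 1.
chunks = [torch.randn(10, 1, 1) for _ in range(5)]

hidden = None  # None makes nn.LSTM start from a zero state on the first chunk
for x in chunks:
    target = torch.zeros_like(x)  # dummy target, for illustration only
    out, hidden = lstm(x, hidden)
    loss = criterion(head(out), target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Keep the state values but drop their graph, so the next backward()
    # does not try to go back through this chunk again.
    hidden = tuple(h.detach() for h in hidden)
```

Without the last line, the second call to backward() would try to go through the first chunk's graph again and fail, which is what retain_graph=True would otherwise paper over.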


To do something like this in PyTorch, you could write something like:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Model(nn.Module):
    def __init__(self, hidden_size):
        super(Model, self).__init__()
        self.lstm = nn.LSTMCell(1, hidden_size)

    def forward(self, inputs):
        # inputs is (x, (hx, cx)): the current step plus the previous state
        x, (hx, cx) = inputs
        x = x.view(x.size(0), -1)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
        return x, (hx, cx)

hidden_size = 100  # example value
flag = True  # True at the start of a time series

if flag:  # start of a sequence, or whenever the past cell state is no longer needed
    cx = Variable(torch.zeros(1, hidden_size))
    hx = Variable(torch.zeros(1, hidden_size))
else:  # take the state from the last batch as the start of the new batch
    # the arguments are cut off in the post; re-wrapping .data is assumed
    cx = Variable(cx.data)
    hx = Variable(hx.data)

# input is a Variable of shape (1, number of features) at each time step
input = Variable(torch.zeros(1, 1))  # placeholder input

model = Model(hidden_size)
output, (hx, cx) = model((input, (hx, cx)))

Thank you! This is very helpful.

glad to hear it :+1:

Oh, FYI, this part:

else:     # take the state from the last batch as the start of the new batch
    # the arguments are cut off in the post; re-wrapping .data is assumed
    cx = Variable(cx.data)
    hx = Variable(hx.data)

This serves the same function as the hidden.detach() call mentioned above. I just prefer to create a fresh Variable for each batch rather than reuse the same one.

So here’s the model I’ve been able to make:

#!/usr/bin/env python

import torch
import torch.nn as nn
import torch.optim as opt
import numpy as np

from torch.autograd import Variable
import pandas as pd


num_epochs = 1
hidden_size = 100

def var(tensor):
    # Wrap a tensor in a Variable, moving it to the GPU when one is available
    if torch.cuda.is_available():
        return Variable(tensor.cuda())
    return Variable(tensor)

class Model(nn.Module):
    def __init__(self, input_size=1, hidden_size=hidden_size, output_size=1,
                 num_layers=1):  # default of 1 is assumed; the signature is cut off in the post
        super(Model, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        lstm_out, hidden = self.lstm(x, hidden)
        linear_out = self.linear(lstm_out)
        relu_out = self.relu(linear_out)
        out = self.out(relu_out)

        return out, hidden

    def init_hidden(self):
        return (var(torch.zeros(self.num_layers, 1, self.hidden_size)),
                var(torch.zeros(self.num_layers, 1, self.hidden_size)))

def load_data(n):
    ix = np.arange(n)
    data = np.sin(2*np.pi*ix/float(n/2))

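    # Inputs are the series; targets are assumed to be the series shifted one
    # step ahead (the second element of the tuple is cut off in the post)
    return (torch.from_numpy(data[:-1]).float(),
            torch.from_numpy(data[1:]).float())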

def train(ds):
    model = Model()
    # criterion = nn.MSELoss()
    criterion = nn.L1Loss()
    optimizer = opt.Adam(model.parameters(), lr=0.0002)

    if torch.cuda.is_available():
        print("Using GPU")
        model.cuda()  # var() moves the inputs to the GPU, so the model must be there too

    lsum = 0.0
    lcount = 0.0

    hidden = model.init_hidden()

    losses = []

    for epoch in range(num_epochs):
        for i, (_x, _y) in enumerate(zip(ds[0], ds[1])):
            x = torch.zeros(1, 1, 1)
            y = torch.zeros(1, 1, 1)

            x[0][0][0] = _x
            y[0][0][0] = _y

            out, hidden = model(var(x), hidden)
            loss = criterion(out, var(y))
            hidden = (var(hidden[0].data), var(hidden[1].data))

            lsum += loss.item()  # loss.data[0] on PyTorch versions before 0.4
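            lcount += 1

            # Parameter update (assumed; the optimizer is never used in the code as posted)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
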
            if i % 1000 == 0:
                l = lsum / lcount
                print('Epoch: %i, Loss: %.7f' % (epoch, l))
                lsum = 0.0
                lcount = 0.0

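if __name__ == "__main__":
    # The body is cut off in the post; the sequence length here is an assumed value
    train(load_data(10000))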

It does not seem to be able to learn even a simple sine function. Here’s the plot of losses:

The loss just seems to oscillate along with the input.
I’d really appreciate any hint about what I’m doing wrong.

Won’t detaching the hidden Variable cut off BPTT?

Yes, but you’d do it between minibatches. Re-wrapping the state in a new Variable, as in dgriff’s example above, serves the same purpose.
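
A tiny sketch of what gets cut, using a scalar stand-in for the hidden state (the names are made up for illustration, and this assumes a recent PyTorch with plain tensors):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

# State carried over as-is: the gradient reaches w through both "minibatches".
h_prev = w * 3.0            # state produced by the previous minibatch
loss_full = h_prev * w      # current minibatch uses that state
grad_full, = torch.autograd.grad(loss_full, w)

# State detached between minibatches: it is treated as a constant (6.0).
h_prev_cut = (w * 3.0).detach()
loss_cut = h_prev_cut * w
grad_cut, = torch.autograd.grad(loss_cut, w)

print(grad_full.item())  # 12.0, i.e. d(3 w^2)/dw at w = 2
print(grad_cut.item())   # 6.0, only the current step contributes
```

So backpropagation still runs through every time step inside a minibatch; only the contribution of earlier minibatches is dropped.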

Hi, have you been able to figure out the reason why it did not work?