I’m trying to implement an RNN for time-series prediction, and for this I am trying to build a stateful model.
Because I have one huge sequence, I want to reuse the states from previous batches instead of having them reset every time. The Keras RNN class has a stateful parameter that enables exactly this behavior:
stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.
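For comparison, here is a sketch of the same Keras semantics with PyTorch's nn.LSTM. nn.LSTM takes an optional initial (h_0, c_0) and returns the final (h_n, c_n) for every sample in the batch. The sizes, the seed and the chunk length of 25 are arbitrary assumptions. The sketch splits a long sequence along the time axis and starts each chunk from the state the previous chunk ended with. It then checks the result against a single pass over the whole sequence. Row i of h and c always belongs to series i, which matches the Keras index-i guarantee.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary example sizes: 3 independent series, 100 steps, 1 feature each.
batch, seq_len, n_features, hidden_size = 3, 100, 1, 8
lstm = nn.LSTM(n_features, hidden_size, batch_first=True)
series = torch.randn(batch, seq_len, n_features)

with torch.no_grad():
    # Reference: the whole sequence in one call.
    full_out, _ = lstm(series)

    # "Stateful" processing: split along time and carry (h, c) forward.
    # h and c have shape (num_layers, batch, hidden); row i stays with series i.
    state = None  # None means a zero initial state
    chunk_outs = []
    for chunk in torch.split(series, 25, dim=1):
        out, state = lstm(chunk, state)
        chunk_outs.append(out)
    chunked_out = torch.cat(chunk_outs, dim=1)

print(torch.allclose(full_out, chunked_out, atol=1e-5))
```

The two outputs should agree up to floating-point rounding. This covers only the forward values. For training, the carried state would also need to be detached between batches.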
I couldn’t find anything similar in PyTorch, and my attempts to build it manually have failed so far.
What I am trying to do is save the h and c tensors after the forward pass and use them as the initial state of an LSTMCell on subsequent forward calls.
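Here is a minimal sketch of the approach described above, assuming a single LSTMCell layer. The module stores the last (h, c) as a plain attribute and feeds it back in on the next call. StatefulLSTMCell and reset_states are hypothetical names (reset_states is borrowed from Keras) and are not part of PyTorch.

```python
import torch
import torch.nn as nn


class StatefulLSTMCell(nn.Module):
    """Hypothetical helper, not a PyTorch API: an LSTMCell that keeps its
    (h, c) between forward calls, similar to Keras stateful=True."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.state = None  # (h, c) from the previous call; None means zeros

    def reset_states(self):
        # Forget the carried state; the next call starts from zeros.
        self.state = None

    def forward(self, x):
        # x: (batch, input_size); batch size must stay the same between resets.
        h, c = self.cell(x, self.state)
        # Store the values but detach them, so the next batch's backward()
        # stops here instead of reaching into this batch's graph.
        self.state = (h.detach(), c.detach())
        return h


torch.manual_seed(0)
model = StatefulLSTMCell(1, 4)
x = torch.randn(2, 1)  # 2 samples, 1 feature
h1 = model(x)          # starts from zeros
h2 = model(x)          # starts from the state left by the first call
model.reset_states()
h3 = model(x)          # starts from zeros again, so it matches h1
```

Because self.state is a plain attribute rather than a registered buffer, model.to(device) will not move it. The simplest option is to call reset_states() after moving the model.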
I’m very new to PyTorch, so I’m probably doing something very wrong, but for now I’m stuck.
I’d really appreciate any hint or any example of a stateful RNN in PyTorch.
A key point: to keep the hidden state Variable across batches without having to pass retain_graph=True to backward(), you need to detach it from the previous batch's graph with hidden.detach().
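The following sketch shows why the detach matters. It uses a toy nn.LSTMCell, an SGD optimizer and a placeholder loss, all of which are assumptions for illustration. The same loop runs once with hx and cx detached at the start of each batch and once without.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.LSTMCell(1, 4)
opt = torch.optim.SGD(cell.parameters(), lr=0.01)
batches = [torch.randn(2, 1) for _ in range(3)]  # toy data: 3 batches of 2 samples


def train(detach):
    hx = torch.zeros(2, 4)
    cx = torch.zeros(2, 4)
    for x in batches:
        if detach:
            # Keep the values from the previous batch, drop the graph behind them.
            hx, cx = hx.detach(), cx.detach()
        hx, cx = cell(x, (hx, cx))
        loss = hx.pow(2).mean()  # placeholder loss
        opt.zero_grad()
        loss.backward()  # frees this batch's graph afterwards
        opt.step()


train(detach=True)  # runs through all batches

failed = False
try:
    # Without detach, the second backward() must go back through batch 1's
    # graph, which was already freed, so autograd raises a RuntimeError.
    train(detach=False)
except RuntimeError:
    failed = True
print("RuntimeError without detach:", failed)
```

Passing retain_graph=True instead keeps the earlier graphs alive, which means every backward() call walks back through all previous batches. Detaching gives truncated backpropagation through time with a fixed cost per batch.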
To do something like this in PyTorch, you could write:
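import torch
import torch.nn as nn
from torch.autograd import Variable  # since PyTorch 0.4, Variable(t) just returns a Tensor

class Model(nn.Module):
    def __init__(self, hidden_size):
        super(Model, self).__init__()
        self.lstm = nn.LSTMCell(1, hidden_size)  # 1 input feature per time step

    def forward(self, inputs):
        # inputs is a tuple: (x, (hx, cx))
        x, (hx, cx) = inputs
        x = x.view(x.size(0), -1)  # flatten to (batch, num_features)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx  # the new hidden state is the output
        return x, (hx, cx)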
hidden_size = 16  # example value; the original leaves it unspecified
model = Model(hidden_size)  # build the model once, outside the batch loop

flag = True  # marks the start of a time series
for step in range(5):  # hypothetical loop over consecutive batches
    if flag:  # beginning of the data, or the past cell state is no longer needed
        cx = Variable(torch.zeros(1, hidden_size))
        hx = Variable(torch.zeros(1, hidden_size))
        flag = False
    else:  # use the cell state from the last batch as the start of this one
        cx = Variable(cx.data)
        hx = Variable(hx.data)
    # input: Variable of shape (1, num_features) at each time step (1 feature here)
    x_t = Variable(torch.randn(1, 1))  # placeholder data
    output, (hx, cx) = model((x_t, (hx, cx)))
The else branch above, which wraps cx.data and hx.data in new Variables, serves the same function as the hidden.detach() call mentioned earlier. I just prefer to create a fresh Variable for each batch rather than reuse the same one.
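Here is a quick check of that equivalence. It assumes a recent PyTorch, where Variable still exists but simply returns a Tensor. Both forms keep the values of cx and drop its autograd history.

```python
import torch
import torch.nn as nn
from torch.autograd import Variable  # legacy name; returns a plain Tensor

torch.manual_seed(0)
cell = nn.LSTMCell(1, 4)
hx, cx = cell(torch.randn(2, 1))  # hx and cx carry autograd history

fresh = Variable(cx.data)  # the style used in the answer above
detached = cx.detach()     # the same effect with detach()

for t in (fresh, detached):
    # Same values, no grad_fn, no requires_grad: the link to the old graph is cut.
    print(torch.equal(t, cx), t.grad_fn is None, t.requires_grad)
```

In PyTorch 0.4 and later, cx.detach() is the more common spelling.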