Copy and restore the hidden state of an RNN

I am using an LSTM in my model, and my forward function is defined as:

self.rnn = nn.LSTM(...)

def forward(self, input, hidden):
    # dropout, linear embedding
    emb = self.drop(self.encoder(input.contiguous().view(-1, self.num_in_features)))
    emb = emb.view(-1, input.size(1), self.rnn_hid_size)

    output, hidden = self.rnn(emb, hidden)
    output = self.drop(output)
    # [(seq_len x batch_size) * feature_size]
    decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
    # [ seq_len, batch_size, feature_size]
    decoded = decoded.view(output.size(0), output.size(1), decoded.size(1))

    return decoded, hidden, output
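For anyone who wants to run the forward pass above, here is a minimal self-contained module built around it. Everything outside `forward` is an assumption, since the post elides it: the class name `SeqModel`, the constructor arguments, `nn.Linear` layers for `encoder` and `decoder`, and the hypothetical `init_hidden` helper that returns zero `(h_0, c_0)` tensors. The decoder is assumed to map back to `num_in_features` so that the output can be fed back in as the next input, as the inference loop does.

```python
import torch
import torch.nn as nn


class SeqModel(nn.Module):
    # Hypothetical wrapper: layer types and sizes are assumptions, not from the post.
    def __init__(self, num_in_features, rnn_hid_size, num_layers=1, dropout=0.2):
        super().__init__()
        self.num_in_features = num_in_features
        self.rnn_hid_size = rnn_hid_size
        self.num_layers = num_layers
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Linear(num_in_features, rnn_hid_size)
        self.rnn = nn.LSTM(rnn_hid_size, rnn_hid_size, num_layers)  # seq-first layout
        # Maps back to the input size so the output can be fed back as input.
        self.decoder = nn.Linear(rnn_hid_size, num_in_features)

    def init_hidden(self, batch_size):
        # Zero (h_0, c_0), each of shape (num_layers, batch, hidden),
        # created with the same dtype/device as the model parameters.
        weight = next(self.parameters())
        shape = (self.num_layers, batch_size, self.rnn_hid_size)
        return weight.new_zeros(shape), weight.new_zeros(shape)

    def forward(self, input, hidden):
        # input: (seq_len, batch, num_in_features)
        emb = self.drop(self.encoder(input.contiguous().view(-1, self.num_in_features)))
        emb = emb.view(-1, input.size(1), self.rnn_hid_size)
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
        decoded = decoded.view(output.size(0), output.size(1), decoded.size(1))
        return decoded, hidden, output
```

With this sketch, `model.init_hidden(1)` gives the batch-size-1 state used in the inference loop, and `decoded` has the same shape as `input`.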

I am doing some inference as follows:

hidden = model.init_hidden(1)

with torch.no_grad():
    for i in range(end_point):
        if i >= start_point:
            out, hidden, _ = model.forward(out, hidden)
        else:
            out, hidden, _ = model.forward(my_input_seq[i].unsqueeze(0), hidden)

Now consider this statement:

out, hidden, _ = model.forward(out, hidden)

After getting its output, I want to undo this statement, i.e. restore the LSTM to the state it was in before the call. I guess this means somehow saving and restoring the hidden state from before the call. What would be a fast (and hopefully easy) way to do this in PyTorch?

Hi @Luca_Pamparana,

I hope I'm not misunderstanding you, but can't you just discard the hidden state returned by that line, i.e. change it to

out, _, _ = model.forward(out, hidden)
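Discarding the returned state works because `nn.LSTM` returns new `(h_n, c_n)` tensors and leaves the `hidden` tuple you pass in unchanged. A minimal check, using a bare `nn.LSTM` with arbitrary sizes as a stand-in for the model in the question:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=8)  # arbitrary sizes, for illustration only
hidden = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))  # (h_0, c_0)
x = torch.randn(1, 1, 4)  # (seq_len=1, batch=1, features=4)

with torch.no_grad():
    out1, _ = lstm(x, hidden)  # new state is thrown away; `hidden` is not modified
    out2, _ = lstm(x, hidden)  # same input and same state, so the same output

print(torch.allclose(out1, out2))  # True
```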

Otherwise, if you need the old hidden state later in your code, I would store it in an extra variable before the call. If that causes memory issues on the GPU, you could move the saved copy to the CPU (maybe add an extra parameter to the model's forward that controls where it is stored).
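A sketch of the extra-variable approach, assuming the usual `(h, c)` tuple returned by `nn.LSTM`. The helpers `snapshot` and `restore` and the `offload_to_cpu` flag are hypothetical names for illustration, not PyTorch API. If dropout is active (training mode), repeating a step from the same state will not reproduce the same output, so the sketch puts the module in eval mode.

```python
import torch
import torch.nn as nn


def snapshot(hidden, offload_to_cpu=False):
    # Hypothetical helper. nn.LSTM never writes into the (h, c) tensors it
    # receives, so keeping references is enough; .cpu() optionally moves the
    # saved state to host memory to free GPU memory.
    h, c = hidden
    if offload_to_cpu:
        return h.cpu(), c.cpu()
    return h, c


def restore(saved, device):
    # Hypothetical helper: move the saved state back to the model's device.
    h, c = saved
    return h.to(device), c.to(device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lstm = nn.LSTM(input_size=4, hidden_size=8).to(device).eval()
hidden = (torch.zeros(1, 1, 8, device=device), torch.zeros(1, 1, 8, device=device))
x = torch.randn(1, 1, 4, device=device)

with torch.no_grad():
    saved = snapshot(hidden, offload_to_cpu=True)  # state before the call
    out, hidden = lstm(x, hidden)                  # the step we want to undo
    hidden = restore(saved, device)                # back to the earlier state
    out_again, _ = lstm(x, hidden)                 # repeats the undone step exactly

print(torch.allclose(out.cpu(), out_again.cpu()))  # True
```

Keeping a plain reference means the old state stays allocated on the GPU; offloading with `.cpu()` trades that memory for a host-device copy on every save and restore.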
