I am using an LSTM in my model, and my forward
function is defined as:
...
self.rnn = nn.LSTM(...)

def forward(self, input, hidden):
    # dropout, linear embedding
    emb = self.drop(self.encoder(input.contiguous().view(-1, self.num_in_features)))
    emb = emb.view(-1, input.size(1), self.rnn_hid_size)
    output, hidden = self.rnn(emb, hidden)
    output = self.drop(output)
    # [(seq_len * batch_size) x feature_size]
    decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
    # [seq_len, batch_size, feature_size]
    decoded = decoded.view(output.size(0), output.size(1), decoded.size(1))
    return decoded, hidden, output
I am running inference as follows:
model.eval()
hidden = model.init_hidden(1)
with torch.no_grad():
    for i in range(end_point):
        if i >= start_point:
            out, hidden, _ = model.forward(out, hidden)
        else:
            out, hidden, _ = model.forward(my_input_seq[i].unsqueeze(0), hidden)
Now, after I run this statement

out, hidden, _ = model.forward(out, hidden)

and get the output, I want to be able to undo it, i.e. restore the LSTM state to what it was before the call. I am guessing this means somehow restoring the hidden state to its previous value. What would be a fast (and hopefully easy) way to achieve this in PyTorch?
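The only approach I can think of so far is to keep a copy of the hidden tuple before the call and reassign it afterwards (assuming hidden is the usual (h, c) tuple of tensors returned by nn.LSTM). A minimal sketch of what I mean:

# sketch: snapshot the (h, c) tuple before the step, restore it afterwards
saved_hidden = tuple(h.clone() for h in hidden)  # copy of pre-call state
out, hidden, _ = model.forward(out, hidden)      # run the step
# ... use `out` ...
hidden = saved_hidden                            # "undo": go back to the pre-call state

I am not even sure the clone is necessary here, since nn.LSTM returns a new hidden tuple rather than modifying the one I pass in, and the module itself does not seem to store any state between calls. Is keeping a copy like this sufficient, or is there a better way?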