Stateful LSTM in PyTorch


Where should I initialize the hidden state and cell state to make an LSTM stateful?


One way could be to add a wrapper nn.Module that contains the LSTM as a submodule and calls it with the hidden state.
Do you want the state to be carried over from the previous (forward) call, or do you want the initial state to be learned (but the same for consecutive calls)? For the former, you would need to avoid autograd disasters by detaching the state (or at least doing so often enough), since otherwise each backward pass tries to go back through the graphs of all earlier calls. You would probably also store the state in a buffer (as in module.register_buffer). For the latter, you could use an nn.Parameter.
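
To make the first option concrete, here is a minimal sketch of such a wrapper. The class name StatefulLSTM, the reset_state helper, and the fixed batch_size argument are my own assumptions for illustration (a buffer needs a known shape up front), not something built into PyTorch.

```python
import torch
import torch.nn as nn


class StatefulLSTM(nn.Module):
    """Hypothetical wrapper: carries (h, c) over from one forward call to the next."""

    def __init__(self, input_size, hidden_size, batch_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Buffers follow .to(device) and are saved in the state_dict,
        # but are not returned by parameters(), so the optimizer ignores them.
        self.register_buffer("h", torch.zeros(num_layers, batch_size, hidden_size))
        self.register_buffer("c", torch.zeros(num_layers, batch_size, hidden_size))

    def reset_state(self):
        # Call this at sequence boundaries (e.g. at the start of each epoch).
        self.h = torch.zeros_like(self.h)
        self.c = torch.zeros_like(self.c)

    def forward(self, x):
        # x: (batch_size, seq_len, input_size)
        out, (h, c) = self.lstm(x, (self.h, self.c))
        # Detach so the next backward pass stops here instead of
        # reaching into the graph of this call.
        self.h = h.detach()
        self.c = c.detach()
        return out
```

Detaching on every call means gradients only flow within a single forward call (truncated backpropagation through time). If you want longer gradient paths, you could detach less often, at the cost of keeping the older graphs alive.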

Best regards


What I am doing is just initializing the states in the __init__ method of the model class.
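
If that corresponds to the second option mentioned above (an initial state that is created once in __init__ and reused for every call), a minimal sketch might look like the following. The class name LSTMWithLearnedInit and the choice to share one learned state across the whole batch are assumptions for illustration only.

```python
import torch
import torch.nn as nn


class LSTMWithLearnedInit(nn.Module):
    """Hypothetical model: the initial (h0, c0) is a learned parameter."""

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # One initial state per layer, with a batch dimension of 1
        # that is expanded to the actual batch size in forward().
        self.h0 = nn.Parameter(torch.zeros(num_layers, 1, hidden_size))
        self.c0 = nn.Parameter(torch.zeros(num_layers, 1, hidden_size))

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        batch = x.size(0)
        # expand() returns a non-contiguous view; nn.LSTM expects contiguous states.
        h0 = self.h0.expand(-1, batch, -1).contiguous()
        c0 = self.c0.expand(-1, batch, -1).contiguous()
        out, _ = self.lstm(x, (h0, c0))
        return out
```

Note that this does not carry anything over between calls: every forward call starts from the same learned state, so no detaching is needed.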