Stateful LSTM training

I am trying to train a stateful ConvLSTM module. Basically, for each batch in mini-batch training I want to save the hidden states so they can be reused for the next batch. If I do that without hidden_state.detach(), it throws an error, but with the detach() call the training runs fine!
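For reference, a minimal sketch of this kind of stateful loop (a plain nn.LSTM stands in for the ConvLSTM module, and the shapes, loop length, and loss are made up for illustration):

```python
import torch
import torch.nn as nn

# Minimal sketch of carrying the hidden state across batches.
model = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
state = None  # (h, c) carried over from the previous batch

for _ in range(4):                      # pretend mini-batches
    x = torch.randn(8, 10, 32)          # (batch, seq_len, features)
    out, state = model(x, state)
    loss = out.mean()                   # dummy loss
    loss.backward()
    # Without detach(), the next backward() would try to backpropagate
    # through the previous batch's graph, which has already been freed,
    # and raise an error. detach() cuts the carried state out of that
    # old graph while keeping its values.
    state = tuple(s.detach() for s in state)
```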

  1. My understanding is that we definitely need hidden_state when computing gradients w.r.t. the loss. So why do we need to call hidden_state.detach()? Won’t this exclude the hidden state from gradient computation?
  2. Secondly, is it correct that optimizer.zero_grad() should be called only when I really want to clear the gradients, and not on every batch’s forward pass, since I need the accumulated gradients until the TBPTT window is over?
  3. Should I call optimizer.step() after every forward pass, or only once the TBPTT window is completed? (See the sketch below.)
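Regarding points 2 and 3, here is a sketch of one possible arrangement: gradients are accumulated over the batches of a window and the optimizer steps once per window. The window length k, the optimizer choice, and the stand-in nn.LSTM are assumptions, not from the original post. Note also that because the state is detached at every batch, gradients do not actually flow across batch boundaries; this only shows where zero_grad()/step() would sit under the accumulate-then-step scheme described above.

```python
import torch
import torch.nn as nn

# Sketch of placing zero_grad()/step() around a TBPTT-style window.
model = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
k = 4                                   # batches per TBPTT window (assumed)
state = None

for i in range(12):                     # pretend stream of mini-batches
    x = torch.randn(8, 10, 32)
    if i % k == 0:
        optimizer.zero_grad()           # clear grads only at window start
    out, state = model(x, state)
    loss = out.mean()                   # dummy loss
    loss.backward()                     # grads accumulate across the window
    state = tuple(s.detach() for s in state)
    if (i + 1) % k == 0:
        optimizer.step()                # single update per window
```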

Hi Bidya, I am facing the exact same issue with my network. Can you please tell me where exactly you put the .detach()? A snippet of the code would be really helpful. Thanks