I am trying to train a stateful ConvLSTM module: for each batch in mini-batch training, I want to save the hidden states so they can be reused as the initial state for the next batch. If I carry the hidden state over without calling `hidden_state.detach()`, it throws an error, but with `detach()` the training runs fine.
- My understanding is that we definitely need `hidden_state` when computing gradients w.r.t. the loss. So why do we need `hidden_state.detach()`? Won't this detach the hidden state from gradient computation entirely?
- My understanding is that `optimizer.zero_grad()` should be called only when I really want to free up the gradients, not on every forward pass, since I need the accumulated gradients until the TBPTT window is over. Is that correct?
- Should I call `optimizer.step()` on every forward pass, or only once the TBPTT window is completed?
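For reference, here is a minimal sketch of the pattern I have in mind, using a plain `nn.LSTM` as a stand-in for the ConvLSTM (names like `tbptt_window` and the shapes are illustrative assumptions, not part of any specific API):

```python
# Sketch: stateful training loop with TBPTT-style accumulation.
# nn.LSTM stands in for a ConvLSTM; tbptt_window is an assumed hyperparameter.
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(head.parameters()), lr=0.01
)

tbptt_window = 3   # mini-batches per TBPTT segment (assumed)
hidden = None      # carried hidden state (h, c), None -> zeros on first batch

for step in range(6):
    x = torch.randn(2, 5, 4)       # (batch, seq, features) — dummy data
    target = torch.randn(2, 1)

    out, hidden = model(x, hidden)
    loss = nn.functional.mse_loss(head(out[:, -1]), target)
    loss.backward()                # gradients accumulate across batches

    # Detach so the next backward() does not try to re-traverse the
    # graph of the previous batch, which has already been freed.
    hidden = tuple(h.detach() for h in hidden)

    if (step + 1) % tbptt_window == 0:
        optimizer.step()           # update once per TBPTT window
        optimizer.zero_grad()      # then clear the accumulated gradients

print(hidden[0].requires_grad)     # False: detached from the old graph
```

Here `detach()` only cuts the backward graph between windows; the detached tensor's values still flow forward into the next batch, so the state itself is preserved.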