Thank you! This fixed my problem, and the network is now training successfully. I am actually still using `Sequential`, but created a `TakeFirst` module that returns the first element of whatever is passed to it, as suggested in this answer: Sequential LSTM II
One thing that is still unclear to me: how do you clear the LSTM's memory? As I understand it, `loss.backward()` does this unless you set `retain_graph=True`. How do you manually clear the LSTM's memory during inference? Moreover, is the state retained across inference runs?
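For context, here is my current understanding in code, as a minimal sketch with a standalone `nn.LSTM` (the sizes are arbitrary, and this is outside my `Sequential` setup):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 10, 8)  # (batch, seq_len, features)

# If no state is passed, nn.LSTM starts from zeros on each call,
# so consecutive forward calls do not share memory by default.
out1, (h, c) = lstm(x)

# To carry state across calls, you have to pass it back in explicitly:
out2, (h, c) = lstm(x, (h, c))

# To "clear" the memory, stop passing the state (or pass explicit zeros):
h0 = torch.zeros(1, 4, 16)  # (num_layers, batch, hidden_size)
c0 = torch.zeros(1, 4, 16)
out3, _ = lstm(x, (h0, c0))  # same result as passing no state at all
```

Is this right, i.e. that the state is only retained if I feed `(h, c)` back in myself, so nothing needs to be cleared between inference runs?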