There is a somewhat complicated implementation of truncated backpropagation through time here.
What I’d like to know is whether it is possible to register a hook on state
such that, when loss.backward() is called, the hook stops backpropagation after a fixed number of steps through the history of state.
Alternatively, can I write a wrapper for state such that its backward() method stops backpropagation after a certain number of calls?
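
For context, the simplest form of truncation I know of is to detach state every k steps, so the graph is cut at fixed boundaries. This is a minimal sketch of that, not the hook-based approach I am asking about. The toy `nn.RNNCell`, the tensor sizes, and the truncation length `k` are all made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: 9 time steps, batch of 1, 3 features, 4 hidden units.
cell = nn.RNNCell(input_size=3, hidden_size=4)
inputs = torch.randn(9, 1, 3, requires_grad=True)
k = 3  # assumed truncation length

state = torch.zeros(1, 4)
for t in range(inputs.size(0)):
    if t % k == 0:
        # Cut the graph here: gradients cannot flow past this point.
        state = state.detach()
    state = cell(inputs[t], state)

# The loss depends only on the final state.
loss = state.sum()
loss.backward()

# Steps before the last detach (t < 6) receive exactly zero gradient.
grad_before_cut = inputs.grad[:6]
grad_after_cut = inputs.grad[6:]
```

The drawback is that the cut points are fixed: a loss taken right after a detach only sees one step of history, rather than always seeing the last k steps, which is why I am asking about a hook or wrapper instead.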