I am training an RNN (more precisely, an LSTM) on video frames. The idea is to pass a window of frames, say t = [0, 8], through the LSTM and use the last hidden state (t = 8) to predict the label associated with t = 4 (the center frame). I could have pre-processed the sequential data so that each t is padded with its neighboring frames, but to be more efficient I would like to design a training loop that reuses the previous computation: when training on t = [1, 9], I would like to reuse the results for t = [1, 8] from the previous forward pass, with only the frame at t = 9 being added. (I understand this is not exactly equivalent to the pre-processed version, since the hidden state at t = 1 is no longer freshly initialized but instead takes the output from t = 0; in my opinion this is not harmful.) Then, before calling loss.backward(), I would simply detach the variables belonging to t = 0 from the compute graph.
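To make the intended loop concrete, here is a minimal sketch of what I have in mind. All names (`window`, `head`, the dummy tensor sizes) are illustrative, not from my real code:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=16, hidden_size=32, batch_first=True)
head = nn.Linear(32, 10)  # predicts the label of the center frame

frames = torch.randn(1, 20, 16)  # (batch, time, features): 20 dummy frames
window = 9                       # t in [0, 8] for the first window

hidden = None
outputs = []
for start in range(frames.size(1) - window + 1):
    if start == 0:
        # First window: feed all 9 frames through the LSTM.
        out, hidden = lstm(frames[:, :window], hidden)
    else:
        # Later windows: reuse the carried hidden state and feed only
        # the single newly added frame at t = start + window - 1.
        out, hidden = lstm(frames[:, start + window - 1 : start + window], hidden)
    # Use the last hidden state to predict the label of the center frame.
    outputs.append(head(out[:, -1]))
```

The open question is how, at each step, to cut the oldest timestep out of the graph before calling `loss.backward()`.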
My main problem is how to detach these specific variables from the compute graph. I tried keeping references to those variables so that I could later call detach() on them to detach the actual variables, but that did not work. Is there a way to do this?
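For context, my understanding of why the reference-keeping attempt failed is that `detach()` is out-of-place: it returns a new tensor and leaves the original node in the graph, so calling it on a saved reference changes nothing. A minimal repro of what I observed:

```python
import torch

x = torch.randn(3, requires_grad=True)
h = x * 2   # intermediate node in the compute graph
ref = h     # "pointer" kept for later, hoping to detach h through it

ref.detach()   # returns a NEW tensor; h itself is untouched

loss = h.sum()
loss.backward()
# Gradients still flow all the way back to x, so the detach had no effect.
print(x.grad)  # → tensor([2., 2., 2.])
```

So what I am really after is a way to sever the graph at a specific past timestep, after the fact.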
Thanks in advance.