Is it actually possible to implement TBPTT with k2 > k1 in PyTorch?

Hi,

I know this post covers the same problem as this one, but I thought that starting afresh might renew interest and, hopefully, lead to a definitive answer.

Much like the author of the preceding post, I have a stateful network that I want to train on long sequences. To work around memory limitations, I would like to use Truncated Backpropagation Through Time (TBPTT) with k2 > k1.
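To make the terminology concrete, here is a small plain-Python sketch of one common reading of TBPTT(k1, k2): every k1 forward steps an update happens, the loss covers the newest k1 steps, and gradients flow back through the last k2 steps. The function name `tbptt_windows` is purely illustrative and not part of any library.

```python
def tbptt_windows(seq_len, k1, k2):
    """List (bp_start, loss_start, end) for each TBPTT(k1, k2) update.

    An update happens every k1 steps, at step `end`; the loss covers steps
    [loss_start, end) and gradients flow back through steps [bp_start, end).
    """
    if not 1 <= k1 <= k2:
        raise ValueError("expected 1 <= k1 <= k2")
    return [(max(0, end - k2), end - k1, end)
            for end in range(k1, seq_len + 1, k1)]


print(tbptt_windows(10, 5, 7))  # [(0, 0, 5), (3, 5, 10)]
```

With k2 > k1 the windows overlap: in the example, steps 3 and 4 are backpropagated through in both updates, and that second trip through steps that were already used is exactly where the problem below shows up.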

@albanD provided a neat answer, but his implementation no longer works on recent PyTorch versions. To make it work, you have to go back to 1.4, which I would rather not do, because the gradients computed there may be wrong.

As mentioned in the most recent comments on that thread, running his code on recent PyTorch versions (e.g., 1.8) produces the following output:

```
doing fw 0
doing fw 1
doing fw 2
doing fw 3
doing fw 4
doing backward 4
doing backward 3
doing backward 2
doing backward 1
doing backward 0
bw: 0.047006845474243164
doing fw 5
doing fw 6
doing fw 7
doing fw 8
doing fw 9
doing backward 9
doing backward 8
doing backward 7
doing backward 6
doing backward 5
doing backward 4
E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\autograd\__init__.py:132: UserWarning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "BPTT.py", line 88, in <module>
    runner.train(input_sequence, torch.zeros(200, layer_size))
  File "BPTT.py", line 28, in train
    output, new_state = self.one_step_module(inp, state)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "BPTT.py", line 65, in forward
    full_out = self.lin(torch.cat([inp, state], 1))
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\nn\functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
 (Triggered internally at …\torch\csrc\autograd\python_anomaly_mode.cpp:104.)
  allow_unreachable=True)  # allow_unreachable flag
Traceback (most recent call last):
  File "BPTT.py", line 88, in <module>
    runner.train(input_sequence, torch.zeros(200, layer_size))
  File "BPTT.py", line 47, in train
    states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [100, 100]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```

As @mpmisko pointed out, “on a second backprop through a given hidden state, I am getting an in-place operation error”. My understanding is that this error occurs because the weights of the linear layer now have a higher _version than they had when the forward pass for an earlier step saved them into the graph. Calling former_state.backward(...), where former_state is a state at step <= 4, needs those saved weights, and autograd sees that they have changed since.
Indeed, the weights were updated in place by optimizer.step() in the meantime, and that is the in-place operation the error refers to (correct me if I’m wrong).
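The explanation above can be checked with a minimal, self-contained reproduction. This is a sketch, not the original BPTT.py: a single nn.Linear stands in for the recurrent layer, and the input requires grad so that backward needs the saved weight, as it does when backpropagating into an earlier hidden state. The in-place weight update is written out by hand to mimic what a plain SGD optimizer.step() does; `_version` is PyTorch's internal version counter.

```python
import torch
import torch.nn as nn


def reproduce_stale_graph_error():
    """Return (version_before, version_after, error_message_or_None)."""
    torch.manual_seed(0)
    lin = nn.Linear(3, 3)
    # x requires grad, so backward must use the saved weight to compute
    # d(out)/dx, much like backprop into an earlier hidden state.
    x = torch.randn(2, 3, requires_grad=True)
    out = lin(x).sum()

    out.backward(retain_graph=True)  # first backward: works
    version_before = lin.weight._version

    # Hand-written stand-in for a plain SGD step: an in-place update,
    # which bumps the weight's version counter.
    with torch.no_grad():
        lin.weight.sub_(0.1 * lin.weight.grad)
    version_after = lin.weight._version

    try:
        out.backward()  # second backward through the graph built before the update
    except RuntimeError as err:
        return version_before, version_after, str(err)
    return version_before, version_after, None


before, after, msg = reproduce_stale_graph_error()
print(after > before)
print(msg)
```

The second backward fails with the same "modified by an inplace operation" RuntimeError as in the log, because the graph still refers to the weight as it was before the update.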

So my question is this: since it is impossible to edit the graph or the _version attribute of the weight tensors, is it actually possible to implement TBPTT with k2 > k1 in PyTorch?

Although this variant of TBPTT seems conceptually simple, I’ve been searching for a while and haven’t found a single working implementation! I am also far from an expert on how PyTorch computes gradients, which is why I’m asking for your help. Sorry for the long post, and thank you in advance for your insights!
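One possible workaround, sketched here under stated assumptions rather than offered as a definitive answer, is to avoid reusing old graphs altogether. The idea: keep the detached states, and after every k1 new steps re-run the forward pass over the last k2 steps starting from the detached state k2 steps back. Compute the loss on the newest k1 outputs only, then call backward and step. Every graph that backward() traverses is then built after the last optimizer.step(), so the version check cannot fail. StepCell, tbptt_recompute, the MSE loss and the tensor shapes are all hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn


class StepCell(nn.Module):
    """Hypothetical one-step recurrent module: state' = tanh(W [x, state] + b)."""

    def __init__(self, in_size, state_size):
        super().__init__()
        self.lin = nn.Linear(in_size + state_size, state_size)

    def forward(self, inp, state):
        new_state = torch.tanh(self.lin(torch.cat([inp, state], 1)))
        return new_state, new_state  # (output, new state)


def tbptt_recompute(cell, head, opt, inputs, targets, state0, k1, k2):
    """TBPTT(k1, k2) by re-running the last k2 steps after every k1 new steps.

    inputs: (T, batch, in_size); targets: (T, batch, out_size).
    Trailing steps that do not fill a full block of k1 are ignored.
    """
    assert k2 >= k1 >= 1
    loss_fn = nn.MSELoss()
    history = [state0.detach()]  # history[t] = detached state entering step t
    losses = []
    for end in range(k1, inputs.size(0) + 1, k1):
        start = max(0, end - k2)
        state = history[start]  # detached, so gradients stop k2 steps back
        loss = 0.0
        for t in range(start, end):
            out, state = cell(inputs[t], state)
            if t >= end - k1:  # loss only from the newest k1 steps
                loss = loss + loss_fn(head(out), targets[t])
                history.append(state.detach())  # state entering step t + 1
        opt.zero_grad()
        loss.backward()  # this graph was built after the last opt.step()
        opt.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
cell, head = StepCell(4, 8), nn.Linear(8, 1)
opt = torch.optim.SGD(list(cell.parameters()) + list(head.parameters()), lr=0.01)
inputs, targets = torch.randn(20, 3, 4), torch.randn(20, 3, 1)
w0 = cell.lin.weight.detach().clone()
losses = tbptt_recompute(cell, head, opt, inputs, targets, torch.zeros(3, 8), k1=2, k2=5)
print(len(losses))  # 10 updates: one per block of k1 = 2 steps
```

The price is extra computation: each update re-runs up to k2 forward steps, so the forward cost grows roughly by a factor of k2 / k1, while activation memory stays proportional to k2. As in ordinary TBPTT, the stored detached states were produced with slightly older weights. For simplicity the sketch keeps the whole history, although only the last k2 states are ever read.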


Hi, I ran into the same issue. Have you managed to solve it?