Implementing Truncated Backpropagation Through Time

Oh wow, exactly what I was wondering, thanks!

On a testing sequence, how can I run the algorithm? Do I use the hidden state of the first step and then propagate it to the next step until the end of the whole sequence?
Also, in case I have an LSTM cell in place of one_step_module: the cell outputs (hn, cn), so what happens with cn?
@albanD @gbarlas
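For the testing-sequence question, my understanding (a sketch, not a definitive answer): at evaluation time nothing is backpropagated, so truncation plays no role and the state can simply be carried from the first step to the end of the sequence. This assumes an `nn.LSTMCell` in place of `one_step_module`; the sizes and the random test data are made up. `cn` is just carried forward together with `hn` as part of the recurrent state.

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes; an nn.LSTMCell stands in for one_step_module
cell = torch.nn.LSTMCell(input_size=4, hidden_size=8)
test_seq = torch.randn(10, 1, 4)  # made-up data: (time, batch, features)

h = torch.zeros(1, 8)
c = torch.zeros(1, 8)
outputs = []
with torch.no_grad():  # no backward pass at test time, so nothing to truncate
    for x_t in test_seq:
        # c is part of the state and is passed on exactly like h
        h, c = cell(x_t, (h, c))
        outputs.append(h)
outputs = torch.stack(outputs)  # shape (10, 1, 8)
```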

@Thanasis_Klaras, did you ever find a successful strategy to apply this to LSTM? I’ve got the same question and am wondering if it makes sense to just re-hook both the hidden states (h & c), with two separate backward calls, in the secondary for loop, e.g.:

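# Gradients that the later step left on its detached (h, c) inputs
curr_h_grad = states[-j-1][0][0].grad
curr_c_grad = states[-j-1][0][1].grad
# Push them into this step's (h, c) outputs. h and c share one graph, so the
# first call must keep it alive, otherwise the second call fails
states[-j-2][1][0].backward(curr_h_grad, retain_graph=True)
states[-j-2][1][1].backward(curr_c_grad, retain_graph=self.retain_graph)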

I made the same modification for my use case. Not sure if it is right, though.
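A sketch of an alternative to the two separate calls, assuming an `nn.LSTMCell` as the one-step module, k1 = 1, and a made-up per-step loss: `torch.autograd.backward()` accepts a list of tensors with matching gradients, so h and c can be pushed back through their shared graph in one call. The loop also clears the `.grad` left on the detached states by earlier steps; without that, gradients from previous iterations leak into the next `curr_grad` when k1 < k2. The helper `grad_or_zero` and all sizes are hypothetical.

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes and data, only for illustration
cell = torch.nn.LSTMCell(input_size=4, hidden_size=8)
k2 = 3          # truncation length; k1 = 1 (update after every step)
T = 6
inputs = torch.randn(T, 1, 4)

states = [(None, (torch.zeros(1, 8), torch.zeros(1, 8)))]

def grad_or_zero(leaf):
    # A state that did not influence the loss has no .grad; use zeros instead
    return leaf.grad if leaf.grad is not None else torch.zeros_like(leaf)

for t in range(T):
    # Detach both halves of the LSTM state so each step gets its own graph
    h_prev, c_prev = (s.detach().requires_grad_(True) for s in states[-1][1])
    h, c = cell(inputs[t], (h_prev, c_prev))
    states.append(((h_prev, c_prev), (h, c)))
    while len(states) > k2:
        del states[0]

    # Clear parameter grads and the stale grads left on the detached states
    cell.zero_grad()
    for leaves, _ in states:
        if leaves is not None:
            for leaf in leaves:
                leaf.grad = None

    loss = h.pow(2).sum()  # hypothetical per-step loss
    loss.backward(retain_graph=True)
    for i in range(k2 - 1):
        if states[-i-2][0] is None:
            break  # reached the initial state
        h_leaf, c_leaf = states[-i-1][0]
        h_out, c_out = states[-i-2][1]
        # One call pushes the gradients of both h and c through the shared graph
        torch.autograd.backward(
            [h_out, c_out],
            [grad_or_zero(h_leaf), grad_or_zero(c_leaf)],
            retain_graph=True,
        )
```

After the last step, the parameter gradients should match a plain backward pass through the last k2 steps started from the same detached state.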



Hi guys,

so first of all thank you very much for the neat code doing TBPTT. I am implementing a Levenberg-Marquardt optimizer for my recurrent neural networks, so I need the Jacobian at each time step. Doing full BPTT at every time step is just prohibitively expensive. So I ran the code on a very small example, only three time steps, for which I calculated the gradient on paper. For the first two time steps, the gradients from PyTorch match those I calculated, but in the third time step, when the inner for-loop is exercised for the first time, the gradient in curr_grad somehow accumulates, so in the end the result is off.

So my example is just an RNN with one state, no inputs, initial state $x_0 = 0.9$ and recurrent weight $w = 0.2$:
$x_1 = w x_0 = 0.18$
$x_2 = w^2 x_0 = 0.036$
$x_3 = w^3 x_0 = 0.0072$

The gradients with respect to $w$ are
$\frac{dx_1}{dw} = x_0 = 0.9$
$\frac{dx_2}{dw} = 2 w x_0 = 0.36$
$\frac{dx_3}{dw} = 3 w^2 x_0 = 0.108$

The gradient in the last time step produced by TBPTT is $\frac{dx_3}{dw} = 0.2880$ instead.
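Spelled out with the chain rule, using only the symbols of this example, a 3-step truncation should add up these three terms at the last step:

$$
\frac{dx_3}{dw}
= \frac{\partial x_3}{\partial w}
+ \frac{\partial x_3}{\partial x_2}\frac{\partial x_2}{\partial w}
+ \frac{\partial x_3}{\partial x_2}\frac{\partial x_2}{\partial x_1}\frac{\partial x_1}{\partial w}
= x_2 + w\,x_1 + w^2 x_0
= 0.036 + 0.036 + 0.036 = 0.108
$$

The factor in front of $\partial x_1/\partial w = x_0$ is the gradient arriving at $x_1$, which should be $w^2 = 0.04$. If it is $0.24$ instead, the last term becomes $0.24 \cdot 0.9 = 0.216$ and the total is $0.036 + 0.036 + 0.216 = 0.288$, exactly the value TBPTT produced.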

My code to reproduce the example is

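import torch

x0 = torch.tensor(0.9, requires_grad=True)
w = torch.nn.parameter.Parameter(torch.tensor(0.2))

states = [(None, x0)]

for j in range(0, 3):
    # Detach the previous state so that each step gets its own small graph
    state = states[-1][1].detach()
    state.requires_grad = True
    new_state = state * w
    states.append((state, new_state))

    # Only keep the gradient of the current time step in w.grad
    w.grad = None
    new_state.backward(retain_graph=True)

    # Backpropagate through up to k2 - 1 = 2 earlier steps
    for i in range(3 - 1):
        if states[-i-2][0] is None:
            break  # reached the initial state
        curr_grad = states[-i-1][0].grad
        states[-i-2][1].backward(curr_grad, retain_graph=True)

    print(j, w.grad.item())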

It seems to me the gradient in states[-i-1][0].grad somehow accumulates. In the third time step (j=2), curr_grad should be 0.2 in the first iteration (i=0), which it is, and 0.04 in the second iteration (i=1), but instead it is 0.24, the sum of both.

Could somebody point my mistake out to me or give me an idea how to fix this accumulation?

Best regards



If your net is shallow enough that the grad for one state can actually be the same as the one from another state, you can do curr_grad = states[-i-1][0].grad.clone() to make sure that your curr_grad won’t be modified by subsequent .backward() calls.

Thank you for the code! I tried implementing this solution for an LSTM architecture but I can’t get the gradient to backpropagate through the hidden states. I give more details about the problem in this question, any help would be really appreciated!

Hi there,

thank you very much for your answer. I replaced my line curr_grad = states[-i-1][0].grad with your suggestion curr_grad = states[-i-1][0].grad.clone(); unfortunately, the result is still the same. I solved the problem with another, less elegant approach which yields the desired gradient. The basic idea is to retain as many computational graphs as there are time steps that should be backpropagated (equal to the truncation depth k). This unfortunately involves calling the forward module multiple times (k times) in each iteration, which I wanted to avoid since it seems kind of brute force to me.

Here it goes:

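import torch

x0 = torch.tensor(0.9, requires_grad=True)
w = torch.nn.parameter.Parameter(torch.tensor(0.2))

depth = 3  # truncation depth k
# Keep one copy of the state per graph that is kept alive
states = []
for i in range(depth):
    states.append(x0.clone().detach().requires_grad_(True))

for j in range(0, 3):
    # Advance every retained graph by one time step
    for i in range(len(states)):
        states[i] = states[i]*w
    # Detach the newest state and prepend it as the start of a fresh graph
    states = [states[0].clone().detach().requires_grad_(True)] + states
    # Call backward on the oldest graph, which spans up to the last `depth` steps
    w.grad = None
    states[-1].backward()
    print(j, w.grad.item())
    # Delete the last (oldest) element in states
    del states[-1]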

@ImTheLazyBoy This should work for your LSTM module too, if you replace my very simple recurrence states[i] = states[i]*w with the one-step forward function of your LSTM module.
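A possible explanation for the accumulation in the three-step example, checked by hand against its numbers and assuming the loop follows the TBPTT code from this thread: the detached states are leaves, so the backward() call of the previous time step already wrote a gradient (0.2) into their .grad, and the inner loop of the next step adds to that stale value (0.2 + 0.04 = 0.24). That would also explain why clone() does not help: the stale value is already in .grad before the current step starts. A minimal sketch that keeps the cheap single-graph loop but resets those leaf gradients at every step (k2 = 3):

```python
import torch

x0 = torch.tensor(0.9, requires_grad=True)
w = torch.nn.Parameter(torch.tensor(0.2))
k2 = 3  # truncation depth

states = [(None, x0)]
grads = []
for j in range(3):
    state = states[-1][1].detach().requires_grad_(True)
    new_state = state * w
    states.append((state, new_state))
    while len(states) > k2:
        del states[0]  # drop graphs older than k2 steps

    # Reset w.grad AND the .grad left on the detached states by the previous step
    w.grad = None
    for leaf, _ in states:
        if leaf is not None:
            leaf.grad = None

    new_state.backward(retain_graph=True)
    for i in range(k2 - 1):
        if states[-i-2][0] is None:
            break  # reached the initial state
        curr_grad = states[-i-1][0].grad
        states[-i-2][1].backward(curr_grad, retain_graph=True)
    grads.append(w.grad.item())

print(grads)  # approximately [0.9, 0.36, 0.108]
```

This only needs one forward call per step, unlike the approach that keeps k separate graphs.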

Hi, thanks for the implementation. I have a question about the retain_graph parameter. When k2 > k1, every backward call keeps the graph, so all the intermediate results are kept in memory. If the sequence is long, won't this method cause an out-of-memory error? Is there something I understand wrong? Or does the line del states[0] free the parts of the graph that are no longer needed?

retain_graph only says that autograd should not free the graph as it goes through it.
But like any other object in Python, the graph is freed when you cannot access it anymore. In particular here, when you delete the corresponding state, nothing points to that part of the graph anymore and it is freed.

Hi, in your code there is only one layer. If I have multiple layers, I need to call tensor.backward(grad) several times, and the repeated backward calls take a lot of time. Is there a better way to do this?

If all the layers are from a single step, you can store them as a single block and call a single backward for all of them.

What is “a single block”? Can I just cat the tensors into A, cat their grads into B, and call A.backward(B)?

A single block corresponds to one_step_module() in the code sample above.
If you have multiple elements, you can use autograd.backward() to give multiple Tensors to backward at once (and it will accumulate all the gradients).
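A small sketch comparing the options, assuming a made-up two-layer step where the second layer depends on the first: a single `torch.autograd.backward()` call, the `cat` variant from the question, and separate calls all accumulate the same parameter gradients. The separate calls traverse the part of the graph under `n1` twice, which is where the extra time goes; `cat` also works but adds an extra node and needs outputs whose shapes can be concatenated. Weights, states and upstream gradients are all hypothetical.

```python
import torch

torch.manual_seed(0)
# Hypothetical two-layer step: each layer keeps its own recurrent state
w1 = torch.nn.Parameter(torch.randn(3, 3))
w2 = torch.nn.Parameter(torch.randn(3, 3))

s1 = torch.randn(1, 3)
s2 = torch.randn(1, 3)
n1 = torch.tanh(s1 @ w1)           # layer 1 new state
n2 = torch.tanh(n1 @ w2 + s2)      # layer 2 new state (depends on n1)

# Made-up upstream gradients, standing in for what the next step would provide
g1 = torch.randn(1, 3)
g2 = torch.randn(1, 3)

def snapshot():
    # Copy the accumulated grads, then clear them for the next variant
    grads = (w1.grad.clone(), w2.grad.clone())
    w1.grad = None
    w2.grad = None
    return grads

# 1) One call for the whole block: the graph is traversed once
torch.autograd.backward([n1, n2], [g1, g2], retain_graph=True)
single = snapshot()

# 2) The "cat" variant: concatenate outputs and their gradients
torch.cat([n1, n2]).backward(torch.cat([g1, g2]), retain_graph=True)
catted = snapshot()

# 3) Separate calls: the subgraph under n1 is traversed twice
n1.backward(g1, retain_graph=True)
n2.backward(g2)
separate = snapshot()
```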

First of all thanks for the code. @albanD I want to apply TBPTT on part of my model, and I got a warning UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). when calling loss.backward(retain_graph=self.retain_graph). May I know if I could get around this? Thank you very much!


Where is this warning thrown? As the message suggests, it only appears because you try to access the .grad field of a Tensor whose .grad field won’t be populated. So you most likely don’t want to do that.
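A minimal sketch of what the warning refers to, with made-up values: `.grad` is only filled in for leaf tensors, so an intermediate result of a forward pass keeps `.grad = None` unless `retain_grad()` is called on it before `backward()`. In the TBPTT loop, the detached states are leaves (their `requires_grad` is set by hand), which is why their `.grad` is available.

```python
import warnings
import torch

w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
x = torch.tensor([3.0, 4.0])

# Case 1: intermediate (non-leaf) tensor without retain_grad()
h_plain = x * w
h_plain.sum().backward()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # this access triggers the UserWarning
    plain_grad = h_plain.grad        # None: autograd did not store it

# Case 2: same intermediate, but with retain_grad() before backward
w.grad = None
h = x * w                            # h = [3., 8.]
h.retain_grad()
(h ** 2).sum().backward()
print(h.grad)                        # d(sum h^2)/dh = 2h = [6., 16.]
```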

Hi, thanks for replying. Basically, I have a network that consists of two parts, say A and B. A produces a 2D list of the LSTM’s hidden and cell state tensors h and c, while B is a CNN that takes the output of A as input and produces the final prediction tensors. So essentially I was asking for the gradients of the output of A, which is also where the warning comes from.

@albanD Hi, I managed to solve the issues I had previously by adding .retain_grad() to the intermediate variables in the forward pass of my model and returning them along with the final prediction. But now I am facing the

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [16]] is at version 2; expected version 1 instead.

error caused by

retain_graph = True.

In my case, K2 = 9 and K1 = 1, because I want to backprop through 9 time steps every time 1 new frame is added, while keeping the total number of frames/states at 9. I pretty much followed the code structure that you provided, except for computing gradients individually for all the tensors in my nested lists. Do you have any clue how to get around the above runtime error? Thank you very much!


You should enable anomaly mode to get more information about which Tensor and op are causing this issue.
This is most likely just that you do an inplace op on something that you shouldn’t :slight_smile:
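A small, made-up reproduction of this kind of error and of anomaly mode: `sigmoid` saves its output for the backward pass, so modifying that output in place bumps its version counter and `backward()` raises the same "modified by an inplace operation" RuntimeError. Inside `torch.autograd.detect_anomaly()`, PyTorch additionally warns with the traceback of the forward op whose gradient failed, which points at the offending line. The fix in this toy case is the out-of-place `b = b * 2`.

```python
import torch

a = torch.randn(3, requires_grad=True)

def buggy_loss():
    b = a.sigmoid()   # sigmoid saves its output b for the backward pass
    b.mul_(2)         # in-place update bumps b's version counter
    return b.sum()

def fixed_loss():
    b = a.sigmoid()
    b = b * 2         # out-of-place: the saved output is left untouched
    return b.sum()

error = None
# Anomaly mode adds a warning with the forward-pass traceback of the failing op
with torch.autograd.detect_anomaly():
    try:
        buggy_loss().backward()
    except RuntimeError as e:
        error = str(e)
print(error)

fixed_loss().backward()
```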

Hi @albanD, thanks for your help, I really appreciate it. After days of debugging, it finally works! However, although I fixed things here and there in my code, the final step that made it work was switching PyTorch from 1.7 back to 1.4. I did this after noticing a similar situation in this post, Help me find inplace error, where the OP’s code did not work until they went back from 1.5 to 1.4. I am wondering why this would happen: does it suggest some serious change or bug between versions?
