Implementing Truncated Backpropagation Through Time

Yes, you are right, this is not good.
In this particular case it does not change anything, but it could have.
I modified the original post to use j for one of the two loops.

Here is an updated link


This might be a stupid question, but did you need to store tuples? Couldn’t you have just stored a list of old states?

Hi,

Yes you can, but since all of these are just references to the actual Tensors, it’s not a problem to store each of them multiple times.
So use whichever you’re more comfortable with!
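As a minimal (made-up) illustration of this point, both storage layouts end up holding references to the exact same Tensor objects, so neither one copies any data:

import torch

state = torch.randn(3, requires_grad=True)
new_state = state * 2

as_tuples = [(state, new_state)]  # storing (old, new) pairs, as in the original code
as_list = [state, new_state]      # or just a flat list of states

# both containers reference the very same Tensor objects, nothing is duplicated
assert as_tuples[0][0] is as_list[0]
assert as_tuples[0][1] is as_list[1]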


This is a way to do backprop, but the original is not incorrect; it depends on the problem. If you have a label only at the end of a sequence, then your solution is not applicable, but if you are updating the parameters of a decoder, this is useful. One potential issue (that I’m currently trying to figure out): won’t your hidden state only propagate the error from the last label, since the states are being detached? The gradient accumulation will be on the last hidden state, so when you perform backprop on previous labels, those will only be factored into each label’s respective timestep. Or so it seems to me. Or will the older state still get gradient accumulation from previous label errors? I might be assuming the wrong behavior of detach().


This post [Truncated backprop data clarification] may help you.

Oh wow, that’s exactly what I was wondering, thanks!

On a test sequence, how can I run the algorithm? Will I use the hidden state of the first step and then propagate it to the next step until the end of the whole sequence?
Also, what about the case where I have an LSTM cell in place of one_step_module? The cell outputs (hn, cn), so what happens with cn?
@albanD @gbarlas

@Thanasis_Klaras, did you ever find a successful strategy to apply this to LSTM? I’ve got the same question and am wondering if it makes sense to just re-hook both the hidden states (h & c), with two separate backward calls, in the secondary for loop, e.g.:

# gradients accumulated on the detached (h, c) of the newer step
curr_h_grad = states[-j-1][0][0].grad
curr_c_grad = states[-j-1][0][1].grad
# re-inject them into the (h, c) produced by the older step
states[-j-2][1][0].backward(curr_h_grad, retain_graph=self.retain_graph)
states[-j-2][1][1].backward(curr_c_grad, retain_graph=self.retain_graph)
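One possible refinement (just a sketch, not from the original post): since the second .backward() has to traverse the same graph again, both gradients can be handed to torch.autograd.backward in a single call. This reuses the same states, j and self.retain_graph as the snippet above and assumes states stores ((h_detached, c_detached), (h_new, c_new)) tuples:

# single traversal of the graph, accumulating the gradients for both h and c
torch.autograd.backward(
    [states[-j-2][1][0], states[-j-2][1][1]],
    grad_tensors=[states[-j-1][0][0].grad, states[-j-1][0][1].grad],
    retain_graph=self.retain_graph,
)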

I made the same modification for my use case. Not sure if it is right.


@albanD

Hi guys,

So first of all, thank you very much for the neat code doing TBPTT. I am implementing a Levenberg-Marquardt optimizer for my recurrent neural networks, so I need the Jacobian at each time step. Doing full BPTT every time step is just prohibitively expensive. So I ran the code on a very small example with only three time steps, for which I calculated the gradient on paper. For the first two time steps, the gradients from PyTorch match those I calculated, but in the third time step, when the inner for-loop is exercised for the first time, the gradient in curr_grad somehow accumulates, so in the end the result is off.

So my example is just an RNN with one state, no inputs, initial state x0 = 0.9, and recurrent weight w = 0.2:
x1 = w*x0 = 0.18
x2 = w*x1 = w^2*x0 = 0.036
x3 = w*x2 = w^3*x0 = 0.0072

The gradients with respect to w are:
dx1/dw = x0 = 0.9
dx2/dw = 2*w*x0 = 0.36
dx3/dw = 3*w^2*x0 = 0.108

The gradient in the last time step produced by TBPTT is dx3/dw=0.2880 instead.

My code to reproduce the example is:

import torch

# simple scalar RNN: x_{t+1} = w * x_t, no inputs
x0 = torch.tensor(0.9, requires_grad=True)
w = torch.nn.parameter.Parameter(torch.tensor(0.2))

# each entry is (detached_input_state, output_state); None marks "no history"
states = [(None, x0)]

for j in range(0, 3):

    # detach the previous output so the graph of this step starts here
    state = states[-1][1].detach()
    state.requires_grad = True

    new_state = state * w

    states.append((state, new_state))

    # backward through the current step
    new_state.backward(retain_graph=True)

    # walk back through the older steps, re-injecting the gradient of each
    # detached state into the output of the step before it
    for i in range(3 - 1):
        if states[-i-2][0] is None:
            break

        curr_grad = states[-i-1][0].grad

        states[-i-2][1].backward(curr_grad, retain_graph=True)

    print(w.grad)
    w.grad.zero_()
It seems to me the gradient in states[-i-1][0] somehow accumulates. In the third time step (j=2), it should be states[-0-1][0].grad = 0.2 in the first iteration (i=0), which it is, and it should be states[-1-1][0].grad = 0.04 in the second iteration (i=1), but instead it is 0.24, the sum of both.

Could somebody point my mistake out to me or give me an idea how to fix this accumulation?

Best regards

Alexander

Hi,

If your net is shallow enough that the grad for one state can actually be the same as the one from another state, you can do curr_grad = states[-i-1][0].grad.clone() to make sure that your curr_grad won’t be modified by subsequent .backward() calls.

Thank you for the code! I tried implementing this solution for an LSTM architecture, but I can’t get the gradient to backpropagate through the hidden states. I give more details about the problem in this question; any help would be really appreciated!

Hi there,

thank you very much for your answer. I replaced my line curr_grad = states[-i-1][0].grad with your suggestion curr_grad = states[-i-1][0].grad.clone(), but unfortunately the result is still the same. I solved the problem with another, yet less elegant, approach which yields the desired gradient. The basic idea is to retain as many computational graphs as there are time steps to backpropagate through (equal to the truncation depth k). Unfortunately, this involves calling the forward module multiple times (k times) in each iteration, which I wanted to avoid since it seems kind of brute force to me.

Here it goes:

import torch

x0 = torch.tensor(0.9, requires_grad=True)
w = torch.nn.parameter.Parameter(torch.tensor(0.2))

depth = 3  # truncation depth k: number of retained computational graphs
states = []
for i in range(depth):
    states.append(x0.clone())

for j in range(0, 3):

    # advance every stored copy of the state by one time step
    for i in range(len(states)):
        states[i] = states[i] * w

    # prepend a detached copy of the newest state; its graph starts here
    states = [states[0].clone().detach().requires_grad_(True)] + states

    # call backward on the oldest copy, whose graph spans the last k steps
    states[-1].backward(retain_graph=False)

    # drop the oldest copy, its graph is no longer needed
    del states[-1]

    print(w.grad)
    w.grad.zero_()

@ImTheLazyBoy This should work for your LSTM module too, if you replace my very simple recurrence states[i] = states[i]*w with the one-step forward function of your LSTM module.
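For what it’s worth, here is a rough sketch (not from the original post, with made-up sizes and a placeholder loss) of the same retain-k-graphs idea with an nn.LSTMCell, where each entry of states carries its own (h, c) pair and its own graph:

import torch

cell = torch.nn.LSTMCell(input_size=4, hidden_size=8)  # hypothetical sizes
depth = 3  # truncation depth k

h0 = torch.zeros(1, 8)
c0 = torch.zeros(1, 8)
states = [(h0.clone(), c0.clone()) for _ in range(depth)]

for x_t in torch.randn(10, 1, 4):  # dummy input sequence of length 10
    # advance every stored (h, c) copy by one time step with the same input
    states = [cell(x_t, (h, c)) for (h, c) in states]

    # prepend a detached copy of the newest state; its graph starts here
    h_new, c_new = states[0]
    states = [(h_new.detach().requires_grad_(True),
               c_new.detach().requires_grad_(True))] + states

    # backward through the oldest copy, whose graph spans the last k steps
    # (a real loss on the current output would replace this placeholder)
    states[-1][0].sum().backward()

    # drop the oldest copy and its graph
    del states[-1]

    # ... optimizer step and cell.zero_grad() would go here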

Hi, thanks for the implementation. I have a question about the “retain_graph” parameter. When k2 > k1, all the backward calls keep the graph, so all the intermediate results are kept in memory. If the sequence is long, this method will cause an out-of-memory error, right? Is there something I understand wrong? Or does the line del states[0] cause the now-useless buffers to be freed?

The retain_graph flag only says that the graph should not be deleted as backward goes through it.
But like any other object in Python, it will be freed when you can no longer access it. In particular here, when you delete the corresponding state, nothing points to the graph anymore and it is deleted.

Hi, in your code there is only one layer. If I have multiple layers, I need to call tensor.backward(grad) several times, but several backward calls cost a lot of time. Is there a better way to do this?

If all the layers are from a single step, you can store them as a single block and call a single backward for all of them.

What is “a single block”? Can I just torch.cat the tensors into A, cat their grads into B, and call A.backward(B)?

A single block corresponds to one_step_module() in the code sample above.
If you have multiple elements, you can use autograd.backward() to give multiple Tensors to backward at once (and it will accumulate all the gradients).
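A small (made-up) sketch of what that could look like, calling torch.autograd.backward once with all the stored outputs of a step and their gradients instead of one .backward(grad) per tensor:

import torch

w = torch.nn.Parameter(torch.randn(4, 4))
x = torch.randn(2, 4)

h1 = x @ w   # output of "layer 1" for this step
h2 = h1 @ w  # output of "layer 2" for this step

# gradients that would be re-injected from the detached states of the next step
g1 = torch.ones_like(h1)
g2 = torch.ones_like(h2)

# one traversal of the shared graph; gradients for both outputs accumulate into w.grad
torch.autograd.backward([h1, h2], grad_tensors=[g1, g2])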