# Implementing Truncated Backpropagation Through Time

Oh wow, exactly what I was wondering, thanks!

How can I run the algorithm on a test sequence? Do I use the hidden state of the first step and then propagate it to the next step until the end of the whole sequence?
Also, what if I have an LSTM cell in place of `one_step_module`? The cell outputs `(hn, cn)`, so what happens with `cn`?
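On the test-time question, here is a minimal sketch. It assumes `torch.nn.LSTMCell` in place of `one_step_module` and a hypothetical linear readout `head`. No backward pass is needed at test time, so the state is simply carried from one step to the next under `torch.no_grad()`, from the initial state to the end of the sequence. `LSTMCell` returns the pair `(h, c)`, and both are passed back in as the next state. `c` is never read out; it only carries the cell memory forward.

```python
import torch

torch.manual_seed(0)

input_size, hidden_size, seq_len = 4, 8, 20
cell = torch.nn.LSTMCell(input_size, hidden_size)  # stands in for one_step_module
head = torch.nn.Linear(hidden_size, 1)             # hypothetical readout layer


def run_test_sequence(inputs, state=None):
    """Run the cell over a whole test sequence, carrying (h, c) forward.

    inputs: tensor of shape (seq_len, batch, input_size)
    state:  optional (h, c) tuple; LSTMCell uses zeros when it is None
    """
    outputs = []
    with torch.no_grad():  # no graph is needed at test time
        for x_t in inputs:
            state = cell(x_t, state)  # LSTMCell returns (h_next, c_next)
            h, _c = state             # c is carried to the next step, never read out
            outputs.append(head(h))
    return torch.stack(outputs), state


test_inputs = torch.randn(seq_len, 2, input_size)
preds, (h_last, c_last) = run_test_sequence(test_inputs)
print(preds.shape)  # torch.Size([20, 2, 1])
```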
@albanD @gbarlas

@Thanasis_Klaras, did you ever find a successful strategy to apply this to LSTM? I’ve got the same question and am wondering if it makes sense to just re-hook both the hidden states (h & c), with two separate backward calls, in the secondary for loop, e.g.:

```python
curr_h_grad = states[-j-1].grad
```
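Here is a sketch of the two-state version. It rests on a few assumptions: `torch.nn.LSTMCell` is the one-step module, there is a hypothetical per-step MSE loss on a linear readout `head`, and `k1`/`k2` have the same meaning as in the TBPTT class earlier in the thread. Both `h` and `c` are detached and re-marked as requiring grad at every step. In the backward sweep, their two gradients go to one `torch.autograd.backward` call on the previous step's `(h, c)`, which accumulates the same gradients as two separate `.backward()` calls. The state gradients are also cleared before each sweep, so that a `.grad` left over from an earlier sweep is not added in. This version only accumulates gradients and never calls an optimizer.

```python
import torch


def lstm_tbptt_grads(cell, head, inputs, targets, init_state, k1, k2):
    """Accumulate TBPTT parameter gradients for an LSTMCell (no optimizer step).

    The recurrent state is the tuple (h, c); both parts are detached at every
    step and both are re-hooked during the backward sweep.
    """
    loss_fn = torch.nn.MSELoss()
    retain = k1 < k2
    # Each entry: (input (h, c) of a step, output (h, c) of that step)
    states = [(None, init_state)]
    for j, (x_t, y_t) in enumerate(zip(inputs, targets)):
        # Cut the graph: detach h and c, then make them leaves that record grad
        h, c = (s.detach().requires_grad_() for s in states[-1][1])
        new_h, new_c = cell(x_t, (h, c))
        states.append(((h, c), (new_h, new_c)))
        while len(states) > k2:
            del states[0]  # graphs older than the window become unreachable

        if (j + 1) % k1 == 0:
            # Clear state grads left over from earlier sweeps
            for inp, _ in states:
                if inp is not None:
                    for s in inp:
                        s.grad = None
            loss = loss_fn(head(new_h), y_t)
            loss.backward(retain_graph=retain)
            for i in range(k2 - 1):
                if states[-i - 2][0] is None:  # reached init_state
                    break
                h_in, c_in = states[-i - 1][0]
                # One call re-hooks both h and c of the previous step
                torch.autograd.backward(states[-i - 2][1],
                                        grad_tensors=(h_in.grad, c_in.grad),
                                        retain_graph=retain)


if __name__ == "__main__":
    torch.manual_seed(0)
    cell, head = torch.nn.LSTMCell(3, 5), torch.nn.Linear(5, 1)
    xs, ys = torch.randn(10, 2, 3), torch.randn(10, 2, 1)
    init = (torch.zeros(2, 5), torch.zeros(2, 5))
    lstm_tbptt_grads(cell, head, xs, ys, init, k1=2, k2=4)
    print(cell.weight_hh.grad.shape)  # torch.Size([20, 5])
```

With `k1=1` and `k2` larger than the sequence length, the accumulated parameter gradients should match full BPTT on the summed loss, which makes a convenient sanity check. Calling `optimizer.step()` between sweeps while `k1 < k2` would update weights that the retained older graphs have saved. That is one likely source of the in-place `RuntimeError` reported further down this thread.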

I made the same modification for my use case. Not sure if it is right.


@albanD

Hi guys,

First of all, thank you very much for the neat TBPTT code. I am implementing a Levenberg-Marquardt optimizer for my recurrent neural networks, so I need the Jacobian at each time step, and doing full BPTT at every time step is prohibitively expensive. So I ran the code on a very small example of only three time steps, for which I calculated the gradient on paper. For the first two time steps the gradients from PyTorch match those I calculated. In the third time step, when the inner for-loop is exercised for the first time, the gradient in `curr_grad` somehow accumulates, so in the end the result is off.

So my example is just an RNN with one state, no inputs, initial state $x_0 = 0.9$ and recurrent weight $w = 0.2$:

$$
\begin{aligned}
x_1 &= w x_0 = 0.18\\
x_2 &= w x_1 = w^2 x_0 = 0.036\\
x_3 &= w x_2 = w^3 x_0 = 0.0072
\end{aligned}
$$

The gradients with respect to $w$ are

$$
\begin{aligned}
\frac{dx_1}{dw} &= x_0 = 0.9\\
\frac{dx_2}{dw} &= 2 w x_0 = 0.36\\
\frac{dx_3}{dw} &= 3 w^2 x_0 = 0.108
\end{aligned}
$$

The gradient in the last time step produced by TBPTT is $dx_3/dw = 0.2880$ instead.
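The hand-derived values can be cross-checked by unrolling all three steps without detaching, so that autograd performs full BPTT back to $x_0$. Below is a minimal sketch using the same $x_0 = 0.9$ and $w = 0.2$. `torch.autograd.grad` is called with `retain_graph=True` because each later step backpropagates through the earlier ones again.

```python
import torch

x0 = torch.tensor(0.9)
w = torch.nn.Parameter(torch.tensor(0.2))

x = x0
full_grads = []
for t in range(3):
    x = w * x  # x_{t+1} = w * x_t, graph kept all the way back to x0
    (g,) = torch.autograd.grad(x, w, retain_graph=True)
    full_grads.append(g.item())  # d x_{t+1} / d w through every previous step

print(full_grads)  # approximately [0.9, 0.36, 0.108]
```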

My code to reproduce the example is

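```python
import torch

x0 = torch.tensor(0.9, requires_grad=True)
w = torch.nn.parameter.Parameter(torch.tensor(0.2))

states = [(None, x0)]

for j in range(0, 3):
    # Cut the graph: start this step from a detached copy of the last output
    state = states[-1][1].detach()
    state.requires_grad = True

    new_state = state * w
    states.append((state, new_state))

    # Backprop the newest step, then walk back through the previous steps
    new_state.backward(retain_graph=True)
    for i in range(3 - 1):
        if states[-i-2][0] is None:
            break
        curr_grad = states[-i-1][0].grad
        states[-i-2][1].backward(curr_grad, retain_graph=True)

    print(w.grad)  # gradient of this step's output w.r.t. w
    w.grad.zero_()
```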

It seems to me the gradient in `states[-i-1]` somehow accumulates. In the third time step (j=2), `states[-0-1].grad` should be 0.2 in the first iteration (i=0), which it is. In the second iteration (i=1), `states[-1-1].grad` should be 0.04, but instead it is 0.24, the sum of both.

Could somebody point out my mistake or give me an idea of how to fix this accumulation?

Best regards

Alexander

Hi,

If your net is shallow enough that the grad for one state can actually be the same as the one from another state, you can do `curr_grad = states[-i-1].grad.clone()` to make sure that your `curr_grad` won't be modified by subsequent `.backward()` calls.
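A note on the 0.2880 in the scalar example: the extra amount comes from a gradient left over from the previous time step, not from the current sweep. The detached state created at j=1 received `.grad = 0.2` from that step's `backward()` call. Nothing resets it, because zeroing parameter gradients (for example with `optimizer.zero_grad()`) does not touch the detached states. At j=2 the sweep adds the new contribution 0.04 on top of that 0.2, giving 0.24, so $dx_3/dw$ comes out as $0.036 + 0.036 + 0.24 \cdot 0.9 = 0.288$. The stale value is already in `.grad` before the sweep starts, so `.clone()` cannot remove it. Below is a minimal sketch of one possible fix, assuming the same scalar setup (`x0 = 0.9`, `w = 0.2`, three steps): set the `.grad` of every stored state to `None` before each sweep.

```python
import torch

x0 = torch.tensor(0.9)
w = torch.nn.Parameter(torch.tensor(0.2))
k2 = 3  # truncation depth used in the example

states = [(None, x0)]
tbptt_grads = []
for j in range(3):
    state = states[-1][1].detach().requires_grad_()  # cut the graph here
    new_state = state * w
    states.append((state, new_state))

    # Clear gradients that earlier sweeps left on the stored (detached) states
    for inp, _ in states:
        if inp is not None:
            inp.grad = None

    w.grad = None
    new_state.backward(retain_graph=True)
    for i in range(k2 - 1):
        if states[-i - 2][0] is None:  # reached the initial state
            break
        curr_grad = states[-i - 1][0].grad
        states[-i - 2][1].backward(curr_grad, retain_graph=True)
    tbptt_grads.append(w.grad.item())

print(tbptt_grads)  # approximately [0.9, 0.36, 0.108]
```

The reset loop is the essential addition. With it, the three steps give 0.9, 0.36 and 0.108, matching the hand-derived values.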

Thank you for the code! I tried implementing this solution for an LSTM architecture, but I can't get the gradient to backpropagate through the hidden states. I give more details about the problem in this question; any help would be really appreciated!

Hi there,

Thank you very much for your answer. I replaced my line `curr_grad = states[-i-1].grad` with your suggestion `curr_grad = states[-i-1].grad.clone()`, but unfortunately the result is still the same. I solved the problem with another, albeit less elegant, approach that yields the desired gradient. The basic idea is to retain as many computational graphs as there are time steps to backpropagate through (equal to the truncation depth k). Unfortunately, this involves calling the forward module multiple times (k times) in each iteration. I wanted to avoid that, since it seems kind of brute force to me.

Here it goes:

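```python
import torch

x0 = torch.tensor(0.9, requires_grad=True)
w = torch.nn.parameter.Parameter(torch.tensor(0.2))

depth = 3  # truncation depth k: one graph copy per step to backpropagate
states = []
for i in range(depth):
    states.append(x0.clone())

for j in range(0, 3):
    # Advance every graph copy by one time step
    for i in range(len(states)):
        states[i] = states[i] * w

    # Detach last state and insert it at the front of the list,
    # where it starts a fresh graph that is backpropagated `depth` steps later
    states.insert(0, states[-1].detach())

    # Call backward on last y_i (the copy whose graph spans the most steps)
    states[-1].backward(retain_graph=False)
    print(w.grad)
    w.grad.zero_()

    # Delete last element in states
    del states[-1]
```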

@ImTheLazyBoy This should work for your LSTM module too, if you replace my very simple recurrence `states[i] = states[i]*w` with the one-step forward function of your LSTM module.

Hi, thanks for the implementation. I have a question about the `retain_graph` parameter. When k2 > k1, every backward call keeps the graph, so all the intermediate results are kept in memory. If the sequence is long, this method will cause an out-of-memory error, right? Or am I misunderstanding something? Does the line `del states` make the function free the cache that is no longer needed?

The `retain_graph` flag only says that the graph should not be deleted as backward goes through it.
But like any other object in Python, the graph will be freed when you cannot access it anymore. In particular, here, when you delete the corresponding state, nothing points to the graph anymore and it is deleted.
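As a small illustration of that point, here is an alternative to the list plus explicit `del` (not what the TBPTT code earlier in the thread does). A `collections.deque` with `maxlen=k2` drops the oldest `(input, output)` pair by itself on every `append`, so at most `k2` step graphs stay reachable however long the sequence is. The scalar recurrence and sizes below are hypothetical.

```python
from collections import deque

import torch

k2 = 4
w = torch.nn.Parameter(torch.tensor(0.5))
# A deque with maxlen drops the oldest (input, output) pair automatically,
# so at most k2 step graphs stay reachable no matter how long the sequence is.
states = deque([(None, torch.tensor(1.0))], maxlen=k2)

for j in range(10):
    state = states[-1][1].detach().requires_grad_()
    new_state = state * w  # one small graph per step
    states.append((state, new_state))

print(len(states))          # 4
print(states[0][1].item())  # output of step 6: 0.5 ** 7 = 0.0078125
```

The negative indexing used in the backward sweep (`states[-i-1]`, `states[-i-2]`) works on a deque as well.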

Hi, in your code there is only one layer. If I have multiple layers, I need to call `tensor.backward(grad)` several times, but several backward calls take a lot of time. Is there a better way to do this?

If all the layers are from a single step, you can store them as a single block and call a single backward for all of them.

What is “a single block”? Can I just `torch.cat` the tensors to `A`, `cat` their grads to `B` and call `A.backward(B)`?

A single block corresponds to `one_step_module()` in the code sample above.
If you have multiple elements, you can use `autograd.backward()` to give multiple Tensors to backward at once (and it will accumulate all the gradients).
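Here is a minimal sketch comparing the options. It assumes a hypothetical step with two stacked layers whose outputs `a` and `b` both receive an incoming gradient (`ga`, `gb`). A single `torch.autograd.backward([a, b], grad_tensors=[ga, gb])` gives the same parameter gradients as two separate `.backward()` calls, in one pass over the graph. The `torch.cat` version asked about above also gives the same gradients, at the cost of building the concatenated tensors.

```python
import torch

torch.manual_seed(0)
layer1 = torch.nn.Linear(3, 4)
layer2 = torch.nn.Linear(4, 4)
x = torch.randn(2, 3)
ga, gb = torch.randn(2, 4), torch.randn(2, 4)  # incoming grads for each layer output


def step():
    a = torch.tanh(layer1(x))  # layer-1 state of this time step
    b = torch.tanh(layer2(a))  # layer-2 state of this time step
    return a, b


def grads():
    """Copy the current parameter grads, then clear them."""
    out = [p.grad.clone() for p in (*layer1.parameters(), *layer2.parameters())]
    layer1.zero_grad(set_to_none=True)
    layer2.zero_grad(set_to_none=True)
    return out


# 1) two separate backward calls (the first must keep the shared graph)
a, b = step()
a.backward(ga, retain_graph=True)
b.backward(gb)
separate = grads()

# 2) a single call: gradients from both outputs are accumulated in one pass
a, b = step()
torch.autograd.backward([a, b], grad_tensors=[ga, gb])
single = grads()

# 3) concatenate outputs and gradients, then one .backward()
a, b = step()
torch.cat([a, b]).backward(torch.cat([ga, gb]))
concat = grads()

print(all(torch.allclose(s, t) for s, t in zip(separate, single)))  # True
print(all(torch.allclose(s, t) for s, t in zip(separate, concat)))  # True
```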

First of all, thanks for the code. @albanD I want to apply TBPTT to part of my model, and I got a warning `UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward().` when calling `loss.backward(retain_graph=self.retain_graph)`. Is there a way to get around this? Thank you very much!

Hi,

Where is this warning thrown? As the message suggests, it only appears because you try to access the `.grad` field of a Tensor whose `.grad` field won't be populated, so you most likely don't want to do that.

Hi, thanks for replying. Basically, I have a network that consists of two parts, say A and B. A produces a 2D list of the LSTM's hidden and cell state tensors `h` and `c`, while B is a CNN that takes the output of A as input and produces the final prediction tensors. So essentially I was asking for the gradients of the outputs of A, which is also where the warning comes from.

@albanD Hi, I managed to solve my previous issue by calling `.retain_grad()` on the intermediate variables in the forward pass of my model and returning them along with the final prediction. But now I am facing the
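Here is a tiny sketch of the `.retain_grad()` fix described here, with two hypothetical linear layers standing in for parts A and B. By default `.grad` is only populated for leaf tensors, which is why reading it on an intermediate output triggers the warning. Calling `.retain_grad()` on that tensor during the forward pass tells autograd to keep its gradient.

```python
import torch

torch.manual_seed(0)
part_a = torch.nn.Linear(3, 4)  # stands in for the recurrent part A
part_b = torch.nn.Linear(4, 1)  # stands in for the CNN part B

x = torch.randn(5, 3)
h = part_a(x)    # non-leaf tensor: produced by an operation
h.retain_grad()  # ask autograd to populate h.grad during backward
loss = part_b(h).pow(2).mean()
loss.backward()

print(h.is_leaf, h.grad.shape)  # False torch.Size([5, 4])
```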

`RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor ] is at version 2; expected version 1 instead.`

error caused by

`retain_graph = True`.

In my case `K2 = 9` and `K1 = 1`, because I want to backprop through 9 time steps every time 1 new frame is added, while keeping the total number of frames/states at 9. I pretty much followed the code structure you provided, except that I compute gradients individually for all the tensors in my nested lists. Do you have any idea how to get around the above runtime error? Thank you very much!

Hi,

You should enable anomaly mode to get more information about which Tensor and op are causing this issue.
This is most likely just that you do an inplace op on something that you shouldn't.

Hi @albanD, thanks for your help, I really appreciate it. After days of debugging it finally works! However, although I fixed things here and there in my code, the final step that made it work was switching PyTorch from 1.7 back to 1.4. I did this after noticing a similar situation in this post, Help me find inplace error, where the OP's code did not work until he went back from 1.5 to 1.4. I am wondering why this would happen: does it suggest some serious change or bug between versions?
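A sketch of one common way to hit this error when `K1 < K2` follows. It assumes plain SGD and a single scalar weight (hypothetical names). The retained graph of an older step has saved the weight it used, `optimizer.step()` then updates that weight in place, and the next backward sweep through the older step finds the saved tensor at a newer version. Wrapping forward and backward in `torch.autograd.detect_anomaly()` is the anomaly mode mentioned above. It additionally reports the forward operation whose saved tensor was modified.

```python
import torch

w = torch.nn.Parameter(torch.tensor(0.2))
opt = torch.optim.SGD([w], lr=0.1)

with torch.autograd.detect_anomaly():
    state = torch.tensor(0.9, requires_grad=True)  # detached state, as in TBPTT
    out = state * w                  # mul saves w to compute the gradient for state
    out.backward(retain_graph=True)  # first sweep keeps the graph (K1 < K2)
    opt.step()                       # in-place update of w bumps its version counter
    try:
        out.backward()               # a later sweep reuses the old step's graph
        error = None
    except RuntimeError as e:
        error = str(e)

print(error is not None and "inplace operation" in error)  # True
```

If an older release does not raise this error for the same code, that does not by itself mean the gradients are correct: the retained graph still refers to a weight that has already been updated in place.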
