Implementing Truncated Backpropagation Through Time

Hi,

There have been serious changes, in particular a hardening of our tests for in-place ops.
So if a newer version complains while an older one works fine, it most likely means that the old version was silently giving wrong gradients and that this was fixed in the newer version.
So I would not recommend downgrading your PyTorch version as a “solution”.
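To make the point about in-place checks concrete, here is a minimal sketch (not code from this thread). A tensor that autograd saved for the backward pass is modified in place, the way an optimizer step modifies a weight. backward() then refuses to use the stale value instead of silently returning a wrong gradient.

```python
import torch

# A leaf weight; the product w * w saves w for the backward pass.
w = torch.tensor(2.0, requires_grad=True)
y = w * w

# Simulate an optimizer step: an in-place update done outside of autograd.
with torch.no_grad():
    w.add_(1.0)

try:
    y.backward()
    caught = False
except RuntimeError as err:
    # The saved w is now at a newer version than when y was computed.
    caught = "modified by an inplace operation" in str(err)

print("in-place modification detected:", caught)
```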

Hi,
Unfortunately, running the code on PyTorch 1.6 and 1.7.1 fails:

doing fw 0
doing fw 1
doing fw 2
doing fw 3
doing fw 4
doing backward 4
doing backward 3
doing backward 2
doing backward 1
doing backward 0
bw: 0.047006845474243164
doing fw 5
doing fw 6
doing fw 7
doing fw 8
doing fw 9
doing backward 9
doing backward 8
doing backward 7
doing backward 6
doing backward 5
doing backward 4
E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\autograd\__init__.py:132: UserWarning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "BPTT.py", line 88, in <module>
    runner.train(input_sequence, torch.zeros(200, layer_size))
  File "BPTT.py", line 28, in train
    output, new_state = self.one_step_module(inp, state)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "BPTT.py", line 65, in forward
    full_out = self.lin(torch.cat([inp, state], 1))
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\nn\functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
 (Triggered internally at …\torch\csrc\autograd\python_anomaly_mode.cpp:104.)
  allow_unreachable=True) # allow_unreachable flag
Traceback (most recent call last):
  File "BPTT.py", line 88, in <module>
    runner.train(input_sequence, torch.zeros(200, layer_size))
  File "BPTT.py", line 47, in train
    states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "E:\progs2\Anaconda3\envs\pytorch17\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
    allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [100, 100]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Hey! I am having the same issue as @slaweks17 above. Essentially, on a second backprop through a given hidden state, I get an in-place operation error. I have double-checked my code, and there are no in-place operations in the forward pass, which is supported by the fact that the first backprop works. Do you have any idea what the issue may be?

Hi Alban, thanks for the code. But I’m facing the same problem as @mpmisko and @slaweks17 with PyTorch 1.8.1.
Any help on that would be really appreciated!

I’ve seen on Stack Overflow that, according to this answer, moving the optimizer step outside the scope of the for loop seems to work fine.
Is this a valid approach?

Yes, that would do the trick 🙂

Thanks for the fix, but can you clarify how you moved the optimizer outside the scope of the for-loop?

Hi fynexx, I just moved optimizer.step() two indentation levels to the left, so that it is outside the for loop. It looks like this:

            print("bw: {}".format(time.time()-start))
    optimizer.step()

I have to say, I’m not strong in programming; it just doesn’t throw the error anymore. I hope I didn’t mess up the TBPTT by doing this. I didn’t check exhaustively whether the code does what it is supposed to do.

I am a bit confused by the following lines of your code:

                    curr_grad = states[-j-1][0].grad
                    states[-j-2][1].backward(curr_grad, retain_graph=self.retain_graph)

Even though my k1 = k2 = 10, it throws a runtime error complaining about trying to backward through the graph a second time. To me this error makes sense, as we are calling backward in a loop. Maybe it would help if someone could explain why we need to retain the graph when calling loss.backward but not for the states (isn’t backpropagation through the states handled internally, as in normal BPTT?).
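The rule behind that error can be seen in isolation in a minimal sketch (not from the thread): the first backward() frees the buffers a graph saved, so a second backward() through the same graph raises unless the first call used retain_graph=True. In the TBPTT class in this thread, that is why retain_graph is set to k1 < k2: only overlapping windows backpropagate through the same time step more than once.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
h = torch.tanh(x)  # tanh saves its output for the backward pass

# First backward frees the saved buffers (retain_graph defaults to False).
h.backward()

try:
    h.backward()  # walks the same, now freed, graph a second time
    second_ok = True
except RuntimeError:
    second_ok = False

# With retain_graph=True the buffers survive, so a second pass works
# and the two gradients accumulate into x.grad.
x.grad = None
h2 = torch.tanh(x)
h2.backward(retain_graph=True)
h2.backward()
print("second backward without retain_graph worked:", second_ok)
```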

I have adapted this to make it work for an LSTM (of sequence length = 1).

I replaced these lines

            state = states[-1][1].detach()
            state.requires_grad=True
            output, new_state = self.one_step_module(inp, state)

by these lines:

                h = states[-1][1][0].detach()
                c = states[-1][1][1].detach()
                h.requires_grad = True
                c.requires_grad = True
                state = (h,c)
                output, new_state = self.model(inp, state)

and in the optimization loop I do:

            if (t+1)%self.k1 == 0:
                self.optim_f.zero_grad()
                start = time.time()
                for j in range(self.k2-1): 

                    if j < self.k1:
                        loss = self.l2_fn(outputs[-j-1], targets[-j-1]).mean()
                        loss.backward(retain_graph=True)

                    # if we get all the way back to the init state, stop
                    if states[-j-2][0] is None:
                        break

                    curr_grad_h = states[-j-1][0][0].grad
                    curr_grad_c = states[-j-1][0][1].grad
                    states[-j-2][1][0].backward(curr_grad_h, retain_graph=self.retain_graph) #h0
                    states[-j-2][1][1].backward(curr_grad_c, retain_graph=self.retain_graph) #c0

                print('bw: {}'.format(time.time()-start))
                self.optim_f.step()

How do I make it work with retain_graph=False (k1 == k2)?
Thanks!
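One possible approach, sketched under assumptions (nn.LSTMCell stands in for the poster's model, and backward_through_state is a hypothetical helper name): calling backward on h and then separately on c walks the graph they share twice, so the second call only works if the first one retained it. Passing both tensors and both gradients to a single torch.autograd.backward call traverses that graph once. The sketch compares the result with ordinary backprop through both steps.

```python
import torch
import torch.nn as nn


def backward_through_state(produced_state, detached_state):
    """Hypothetical helper: send the gradients collected on the detached
    (h, c) leaves of step t+1 back through the (h, c) produced at step t."""
    tensors, grads = [], []
    for produced, leaf in zip(produced_state, detached_state):
        if leaf.grad is not None:  # skip a component that received no gradient
            tensors.append(produced)
            grads.append(leaf.grad)
    if tensors:
        # One call traverses the graph shared by h and c once, so the tuple
        # state alone does not force retain_graph=True.
        torch.autograd.backward(tensors, grads)


torch.manual_seed(0)
cell = nn.LSTMCell(4, 8)  # stand-in for the poster's model
x1, x2 = torch.randn(2, 4), torch.randn(2, 4)
h0 = torch.zeros(2, 8, requires_grad=True)
c0 = torch.zeros(2, 8, requires_grad=True)

# Truncated style: detach the state between steps, then reconnect manually.
s1 = cell(x1, (h0, c0))
h1 = s1[0].detach().requires_grad_()
c1 = s1[1].detach().requires_grad_()
s2 = cell(x2, (h1, c1))
s2[0].sum().backward()                # fills h1.grad, c1.grad and the weight grads
backward_through_state(s1, (h1, c1))  # continues into step 1, no retain_graph
truncated = cell.weight_ih.grad.clone()

# Reference: the same two steps as one undetached graph.
cell.zero_grad()
s2_ref = cell(x2, cell(x1, (h0.detach(), c0.detach())))
s2_ref[0].sum().backward()
same = torch.allclose(truncated, cell.weight_ih.grad, atol=1e-6)
print("matches full backprop:", same)
```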

To those with in-place operation problems (@g_hansen):
The problem:
Calling backward for the first time on the second-to-last hidden state, after one optimization step has already taken place, fails because of an in-place operation that happened somewhere earlier.

The actual cause:
When the optimizer step is run for the first time, weights are updated in-place in the respective hidden units or Conv filters, as per the PyTorch implementation. The problematic flow is:

start: hidden states (None, H0)
1. forward() -> loss, hidden state H1
2. loss.backward()
3. optimizer.step() (in-place update of the weights behind the scenes)
4. forward() -> loss, hidden state H2
5. loss.backward()
6. H1.backward() (PyTorch finds that the weights saved in H1’s computation graph have been modified since)

@albanD Do you agree, and is there a way in PyTorch to prevent in-place weight updates?
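The flow described above can be reproduced in a few lines. This is a sketch under assumptions: a single nn.Linear stands in for the recurrent cell, and the sizes and learning rate are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(3, 3)  # stand-in for the recurrent cell
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

h0 = torch.zeros(1, 3, requires_grad=True)  # detached start state, as in the TBPTT loop

# Window 1: forward, backward (keeping H1's graph), then an optimizer step.
h1 = lin(h0)                          # H1's graph keeps a reference to lin.weight
h1.sum().backward(retain_graph=True)
opt.step()                            # in-place update of lin.weight

# Window 2: continue from a detached copy of H1.
h1_leaf = h1.detach().requires_grad_()
h2 = lin(h1_leaf)
h2.sum().backward()

try:
    h1.backward(h1_leaf.grad)         # reaches back into H1's graph
    failed = False
except RuntimeError as err:
    failed = "inplace operation" in str(err)

print("H1.backward() raised the in-place error:", failed)
```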

Found this: pytorch/pytorch issue #39141 on GitHub, “RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation?”
As suggested by @albanD in the above ticket (and further up in this post), the solution is to call the optimizer step once you’ve gone through your whole sequence and thus calculated gradients for all the timesteps.

This is really cool, because now you’re practically doing the optimization step as in full BPTT, where you pass the whole sequence in one mini-batch and then run the optimization. This also means I am going to have to train for days, because I’ll be doing one optimization step every 1.5 hours 😅
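Deferring the step can be checked on a toy setup. In this sketch (assumptions: a single nn.Linear stands in for the recurrent cell, with two one-step windows), all backward passes run before optimizer.step(). No saved weight is stale, and the accumulated gradient matches plain BPTT over both steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(3, 3)  # stand-in for the recurrent cell
opt = torch.optim.SGD(lin.parameters(), lr=0.1)
h0 = torch.randn(1, 3)

# Two truncated windows, accumulating gradients without stepping in between.
h1 = lin(h0)
h1.sum().backward(retain_graph=True)
h1_leaf = h1.detach().requires_grad_()
h2 = lin(h1_leaf)
h2.sum().backward()
h1.backward(h1_leaf.grad)  # weights are unchanged, so the saved values are valid
tbptt_grad = lin.weight.grad.clone()

# Reference: plain BPTT through both steps in one graph.
opt.zero_grad()
r1 = lin(h0)
(r1.sum() + lin(r1).sum()).backward()
same = torch.allclose(tbptt_grad, lin.weight.grad, atol=1e-6)

opt.step()  # a single update, after every backward pass
print("matches full BPTT:", same)
```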

Hey

If it is really inconvenient for you to delay the step function, you can always re-do the forward after updating the weights 🙂
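One way to read that suggestion, sketched under assumptions: keep the last k2 inputs together with the state that preceded each of them, and at every k1 boundary rebuild the graph from the stored start state with the current weights. No graph then refers to weights that optimizer.step() has already modified. OneStep and train_redo_forward are hypothetical names, and the loss is taken on the latest output only, mirroring the TBPTT class in this thread.

```python
import collections

import torch
import torch.nn as nn


class OneStep(nn.Module):
    """Hypothetical stand-in for one_step_module: (inp, state) -> (out, new_state)."""

    def __init__(self, size):
        super().__init__()
        self.lin = nn.Linear(2 * size, 2 * size)

    def forward(self, inp, state):
        out, new_state = self.lin(torch.cat([inp, state], 1)).chunk(2, dim=1)
        return out, new_state


def train_redo_forward(module, loss_fn, optimizer, input_sequence, init_state, k1, k2):
    window = collections.deque(maxlen=k2)  # last k2 (inp, target, state before inp)
    state = init_state
    for j, (inp, target) in enumerate(input_sequence):
        window.append((inp, target, state))
        with torch.no_grad():  # advance the state without building a graph
            _, state = module(inp, state)
        if (j + 1) % k1 == 0:
            optimizer.zero_grad()
            # Re-do the forward over the stored window with the current weights,
            # starting from the (constant) state that preceded the window.
            s = window[0][2]
            for w_inp, _, _ in window:
                out, s = module(w_inp, s)
            loss = loss_fn(out, window[-1][1])  # latest output only
            loss.backward()  # fresh graph: no retain_graph, no stale weights
            optimizer.step()
    return state


torch.manual_seed(0)
size = 8
module = OneStep(size)
optimizer = torch.optim.SGD(module.parameters(), lr=1e-2)
seq = [(torch.rand(4, size), torch.rand(4, size)) for _ in range(20)]
final = train_redo_forward(module, nn.MSELoss(), optimizer, seq, torch.zeros(4, size), k1=5, k2=7)
print(final.shape)  # torch.Size([4, 8])
```

The price is time rather than memory: each update re-runs up to k2 forward steps on top of the no-grad forward that advances the state.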

Great solution! The additional space complexity of storing the past k2 images per iteration is definitely worth the increased number of optimization steps.
Thanks!

Hi @uzborg95 and @albanD, thanks for sharing this. Could you clarify how you “store the past k2 images per iteration” to avoid delaying the step function? As @albanD suggested, we need to re-do the forward after each step, but that increases the time complexity. Is there a way to avoid this?

Hello @uzborg95, could you please share the implementation? I am quite confused about what re-doing the forward after updating the weights actually means… Thanks a lot in advance!

I have the same problem.

The problem is that each iteration of j accumulates gradients into states[i][0].grad, and these stale values are picked up again in the next iteration of j. They need to be zeroed before calculating gradients in the new iteration. In my case the results are correct after this change:

import torch

x0 = torch.tensor(0.9, requires_grad=True)
w = torch.nn.parameter.Parameter(torch.tensor(0.2))

states = [(None, x0)]

for j in range(0, 3):

    # Start each step from a detached leaf so its gradient can be read later.
    state = states[-1][1].detach()
    state.requires_grad = True

    new_state = state * w

    states.append((state, new_state))

    ## Newly added part: the gradients of states have to be cleared first.
    for i in range(1, j + 2):
        if states[i][0].grad is not None:
            states[i][0].grad.zero_()

    new_state.backward(retain_graph=True)

    # Walk back through at most k2 - 1 earlier steps (k2 = 3 here).
    for i in range(3 - 1):
        if states[-i-2][0] is None:
            break

        curr_grad = states[-i-1][0].grad.clone()
        states[-i-2][1].backward(curr_grad, retain_graph=True)

    print('w=', w.grad.detach().numpy())
    w.grad.zero_()

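The following code runs without the error on some newer PyTorch versions: the optimizer’s in-place updates are now performed on the ‘data’ tensors of the parameters, which share storage with the parameters but are not tracked by autograd.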

import time

import torch
import torch.nn as nn


class TBPTT():
    def __init__(self, one_step_module, loss_module, k1, k2, optimizer):
        self.one_step_module = one_step_module
        self.loss_module = loss_module
        self.k1 = k1
        self.k2 = k2
        self.retain_graph = k1 < k2
        # You can also remove all the optimizer code here, and the
        # train function will just accumulate all the gradients in
        # one_step_module parameters
        self.optimizer = optimizer

    def train(self, input_sequence, init_state):
        states = [(None, init_state)]
        for j, (inp, target) in enumerate(input_sequence):

            state = states[-1][1].detach()
            state.requires_grad=True
            output, new_state = self.one_step_module(inp, state)
            states.append((state, new_state))

            while len(states) > self.k2:
                # Delete stuff that is too old
                del states[0]

            if (j+1)%self.k1 == 0:
                loss = self.loss_module(output, target)

                self.optimizer.zero_grad()
                # Clear gradients left on the state leaves by the previous
                # (overlapping) window, otherwise they are counted twice.
                for old_state, _ in states:
                    if old_state is not None:
                        old_state.grad = None
                # backprop last module (keep graph only if they ever overlap)
                start = time.time()
                loss.backward(retain_graph=self.retain_graph)
                for i in range(self.k2-1):
                    # if we get all the way back to the "init_state", stop
                    if states[-i-2][0] is None:
                        break
                    curr_grad = states[-i-1][0].grad
                    states[-i-2][1].backward(curr_grad, retain_graph=self.retain_graph)
                print("bw: {}".format(time.time()-start))
                self.optimizer.step()



seq_len = 20
layer_size = 50

idx = 0

class MyMod(nn.Module):
    def __init__(self):
        super(MyMod, self).__init__()
        self.lin = nn.Linear(2*layer_size, 2*layer_size)
        self.param_data = []
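        def trans_grad(other):
            # Accumulate each gradient reaching the parameter into its .data
            # twin `other`, which is the tensor the optimizer updates.
            def trans_fn(grad):
                other.grad = grad.clone() if other.grad is None else other.grad + grad
            return trans_fn

        for p in self.parameters():
            # p.data shares storage with p but not its version counter, so the
            # optimizer's in-place updates on d are not seen by autograd.
            d = p.data.requires_grad_()
            self.param_data.append(d)
            p.register_hook(trans_grad(d))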

    def forward(self, inp, state):
        global idx
        full_out = self.lin(torch.cat([inp, state], 1))
        # out, new_state = full_out.chunk(2, dim=1)
        out = full_out.narrow(1, 0, layer_size)
        new_state = full_out.narrow(1, layer_size, layer_size)
        def get_pr(idx_val):
            def pr(*args):
                print("doing backward {}".format(idx_val))
            return pr
        new_state.register_hook(get_pr(idx))
        out.register_hook(get_pr(idx))
        print("doing fw {}".format(idx))
        idx += 1
        return out, new_state


one_step_module = MyMod()
loss_module = nn.MSELoss()
input_sequence = [(torch.rand(200, layer_size), torch.rand(200, layer_size))] * seq_len

optimizer = torch.optim.SGD(one_step_module.param_data, lr=1e-3)

runner = TBPTT(one_step_module, loss_module, 5, 7, optimizer)

runner.train(input_sequence, torch.zeros(200, layer_size))
print("done")
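Why the version above runs where the earlier one failed can be shown with a minimal sketch (not from the thread). An in-place change made through .data is invisible to autograd’s version check, so backward() succeeds but uses the updated value. The same change made through detach(), which shares the version counter, is detected. In line with the first reply in this thread, the missing error does not by itself mean the gradients are right.

```python
import torch

# Case 1: update through .data (as MyMod's optimizer does). The saved w
# changes underneath autograd and no error is raised.
w = torch.tensor(2.0, requires_grad=True)
y = w * w            # d(w*w)/dw = 2w = 4 with the original value
w.data.add_(1.0)     # not tracked by autograd's version counter
y.backward()
print(w.grad.item())  # 6.0: computed with the updated w = 3

# Case 2: update through detach(). The detached tensor shares the version
# counter, so the stale saved value is detected.
v = torch.tensor(2.0, requires_grad=True)
z = v * v
v.detach().add_(1.0)
try:
    z.backward()
    detected = False
except RuntimeError:
    detected = True
print("detach() update detected:", detected)
```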