Trouble avoiding "Trying to backward through the graph a second time" error on simple model

Hi,

I’m trying to teach myself the nuts and bolts of PyTorch by implementing a relatively low-level model (a Kalman filter).

I keep running into the error: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I know this is a common error and there are a lot of posts and responses about it, but I haven’t gotten any of the suggested solutions to work. This may be because I can’t see how those situations apply to my model, which makes me concerned I don’t understand the nuts and bolts very well.

I’ve tried to simplify the example as much as possible:

import torch
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor

N = 50
vel_true = torch.cumsum(torch.randn(N)/10., 0) # true velocity
pos_true = torch.cumsum(vel_true, 0) # true position
pos_obs = pos_true + torch.randn(N) # observed position

class KalmanFilter1(torch.nn.Module):
    def __init__(self):
        super(KalmanFilter1, self).__init__()
        self.log_obs_std = Parameter(Tensor([ 0.]))
        
        # if this parameter is set like so, the error doesn't occur:
        # self.log_process_std = Parameter(Tensor([-1.])) 

        # if the parameter is defined as its ratio to the other parameter, the error occurs.
        # There's no real reason to do this, but it was the easiest way to make the error
        # occur in this simple example (I originally encountered the error in a multivariate
        # Kalman filter, but I think that code would just make the problem more obscure)
        self.log_ratio = Parameter(Tensor([ 0.]))
        self.log_process_std = torch.log( torch.exp(self.log_obs_std) * torch.exp(self.log_ratio) )
        
    def predict(self, x, k_mean, k_var):
        k_mean = k_mean  # the state mean is unchanged by the prediction step
        k_var = k_var + torch.pow(torch.exp(self.log_process_std), 2)  # variance grows by the process noise
        return k_mean, k_var
    
    def update(self, x, k_mean, k_var):
        resid = x - k_mean  # innovation: observation minus predicted mean
        K = k_var / (k_var + torch.pow(torch.exp(self.log_obs_std), 2))  # Kalman gain
        k_mean = k_mean + K*resid
        k_var = (1.-K) * k_var
        return k_mean, k_var
        
    def forward(self, x):
        output = []
        
        k_mean = Variable(torch.zeros(1))
        k_var = Variable(torch.ones(1))
        output.append(k_mean)
        for i in range(len(x)-1):
            k_mean, k_var = self.predict(x[i], k_mean, k_var)
            k_mean, k_var = self.update(x[i], k_mean, k_var)
            output.append(k_mean)

        output = torch.cat(output, 0)
        return output

model = KalmanFilter1()
optimizer = torch.optim.RMSprop(model.parameters(),lr=0.01)

for t in range(100):
    obs_x = Variable(pos_obs)
    pred = model(obs_x)
    loss = torch.sum(torch.pow(obs_x - pred, 2))
    optimizer.zero_grad()
    loss.backward(retain_graph=False)
    optimizer.step()

As far as I can understand what retaining the graph means, I don’t think it should be needed here, as each training iteration is independent – I think I must be specifying something incorrectly about my model. Any help conceptually understanding why this error is occurring would be appreciated.

PyTorch uses a dynamic graph; it is not a symbolic-expression framework like TensorFlow. In your example, self.log_process_std is computed from the two Parameters once in __init__, so that little piece of the graph is built only once and shared by every forward pass, and its buffers are freed on the first backward(). You are backpropagating through that same graph multiple times. Compute that value manually at each iteration (i.e. inside forward()) instead.
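
If it helps, here is a minimal sketch of what that change could look like for the example above (this assumes the KalmanFilter1 class from the first post is in scope; the subclass is just for illustration):

class KalmanFilter2(KalmanFilter1):
    def forward(self, x):
        # recompute the derived quantity from the Parameters on every call, so each
        # training iteration backprops through a freshly built piece of graph
        self.log_process_std = torch.log(
            torch.exp(self.log_obs_std) * torch.exp(self.log_ratio)
        )
        return super(KalmanFilter2, self).forward(x)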


Hello,
I faced this error once. Here is the solution: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time

In my case it happened because I called .backward() twice. I fixed it by passing retain_graph=True to the first backward call.
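
For illustration, a minimal standalone sketch of that situation, i.e. two backward calls that traverse the same graph (toy tensors, not the code from this thread):

import torch
from torch.autograd import Variable

w = Variable(torch.randn(3), requires_grad=True)
loss = (torch.exp(w) * w).sum()

loss.backward(retain_graph=True)  # keep the graph's buffers so it can be traversed again
loss.backward()                   # a second backward over the same graph is now allowed

Note that this is the right fix only when you genuinely need to backprop through the same graph twice; in the original post the better fix is to rebuild the graph each iteration, as described above.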

This was the trick! I had exactly the conceptual confusion that you’ve identified. It was also the source of my problems in my original code.

Thanks!


Hey @jwdink, what was the change you made to your code to finally get it to work? I have a similar optimization process I am trying to solve and am not sure how to work around this error.