Trying to backward through the graph second time, but the buffers have already been freed

I am trying to implement the recurrent weighted average (RWA) model in PyTorch.

I replaced the RNN class in Practical PyTorch's char-rnn-classification example with an RWA class.

But it gives me the error below:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
in <module>()
     20 for epoch in range(1, n_epochs + 1):
     21     category, line, category_tensor, line_tensor = training_pair()
---> 22     output, loss = train(category_tensor, line_tensor)
     23     current_loss += loss
     24 

<ipython-input-214-4bee22e367f7> in train(categroy_tensor, line_tensor)
      7     loss = criterion(output, category_tensor)
      8     print("loss:" , loss)
----> 9     loss.backward()
     10 
     11     for p in rnn.parameters():

/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py in backward(self, gradient, retain_variables)
    143                     'or with gradient w.r.t. the variable')
    144             gradient = self.data.new().resize_as_(self.data).fill_(1)
--> 145         self._execution_engine.run_backward((self,), (gradient,), retain_variables)
    146 
    147     def register_hook(self, hook):

/usr/local/lib/python3.5/dist-packages/torch/autograd/_functions/basic_ops.py in backward(self, grad_output)
     37 
     38     def backward(self, grad_output):
---> 39         a, b = self.saved_tensors
     40         return grad_output.mul(b), grad_output.mul(a)
     41 

RuntimeError: Trying to backward through the graph second time, but the buffers have already been freed. Please specify retain_variables=True when calling backward for the first time.

The code I changed looks like this:

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

class RWA(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RWA, self).__init__()
        
        self.max_steps = 1
        self.batch_size = 1
        self.hidden_size = hidden_size
        
        self.n = Variable(torch.Tensor(self.batch_size, hidden_size), requires_grad=True)
        self.d = Variable(torch.Tensor(self.batch_size, hidden_size), requires_grad=True)
        
        self.x2u = nn.Linear(input_size, hidden_size)
        self.c2g = nn.Linear(input_size + hidden_size, hidden_size)
        self.c2q = nn.Linear(input_size + hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
    
    def forward(self, input, hidden):
        h = F.tanh(hidden)
        
        for i in range(len(input)):
            combined = torch.cat((input[i], h), 1)
            
            
            u = self.x2u(input[i])
            g = self.c2g(combined)
            q = self.c2q(combined)
            q_greater = F.relu(q)
            scale = torch.exp(-q_greater)
            a_scale = torch.exp(q-q_greater)
            self.n = (self.n * scale) + ((u * F.tanh(g)) * a_scale)
            self.d = (self.d * scale) + a_scale
            h = F.tanh(torch.div(self.n, self.d))
        output = self.out(h)
        return output, h

    def init_hidden(self):
        return Variable(torch.randn(1, self.hidden_size))

n_hidden = 128
rwa = RWA(n_letters, n_hidden, n_categories)
print("n_letters:", n_letters, "n_hidden:", n_hidden, "n_categories:", n_categories)
print(rwa)
n_letters: 57 n_hidden: 128 n_categories: 18
RNN (
  (x2u): Linear (57 -> 128)
  (c2g): Linear (185 -> 128)
  (c2q): Linear (185 -> 128)
  (out): Linear (128 -> 18)
)

I tried the solution to a similar problem, but it does not work. I expected it to work, because I re-initialize the hidden state at every iteration:

def train(category_tensor, line_tensor):
    hidden = rwa.init_hidden()
    rwa.zero_grad()
    output, hidden = rwa(line_tensor, hidden)
    loss = criterion(output, category_tensor)
    print("loss:" , loss)
    loss.backward()
    
    for p in rwa.parameters():
        p.data.add_(-learning_rate, p.grad.data)
    return output, loss.data[0]
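Re-initializing `hidden` does not help here, because `forward` also overwrites `self.n` and `self.d`, and those attributes survive between calls. The next call therefore builds its graph on top of the previous one, and `loss.backward()` walks back into the part whose buffers the previous backward already freed. A toy sketch of that pattern, assuming a recent PyTorch release (the `Carrier` module is hypothetical and only mimics the state handling of `RWA`):

```python
import torch
import torch.nn as nn


class Carrier(nn.Module):
    """Hypothetical toy module that, like RWA above, keeps running state on `self`."""

    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.ones(1))
        self.n = torch.zeros(1)  # plays the role of self.n in the RWA class

    def forward(self, x):
        # Overwriting self.n links this call's graph to the previous call's graph.
        self.n = self.n * self.w + x
        return self.n


model = Carrier()
errors = []
for step in range(2):
    model.zero_grad()
    out = model(torch.ones(1))  # fresh input every iteration, like init_hidden()
    try:
        out.sum().backward()
    except RuntimeError as e:
        errors.append((step, str(e)))

print(errors)
```

The first step succeeds and the second raises the same "backward through the graph a second time" error. Passing `retain_graph=True` (the later name for `retain_variables`) would hide it, but each step would then also backpropagate through every earlier sequence.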

I suspect self.n and self.d. I tried changing them to nn.Parameter, but then it complains when I try to assign a value here: self.n = (self.n * scale) + ((u * F.tanh(g)) * a_scale). What should I do to solve this problem?
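The complaint after switching to `nn.Parameter` is expected: once a name is registered as a parameter, `nn.Module` only accepts another `nn.Parameter` (or `None`) for that attribute, and `self.n * scale + ...` produces an ordinary tensor. A small sketch reproducing the `TypeError`, assuming a recent PyTorch release (the `WithParam` class is hypothetical):

```python
import torch
import torch.nn as nn


class WithParam(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # Registered as a learnable parameter, like the attempted fix above.
        self.n = nn.Parameter(torch.zeros(1, hidden_size))

    def forward(self, scale):
        # Assigning a computed (non-Parameter) tensor to a registered parameter name fails.
        self.n = self.n * scale
        return self.n


m = WithParam(4)
try:
    m(torch.ones(1, 4))
    message = None
except TypeError as e:
    message = str(e)

print(message)
```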

1 Like

Author of the RWA model here. I saw your post, and I wanted to let you know that a flaw has been discovered in my code. The flaw concerns the numerical stability of the RWA model; if left uncorrected, it prevents the model from forming long-term memories. Once you fix your code, you may find that the issue you are having goes away. Maybe so, maybe not.
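The post does not spell out the fix, and the linked repository is the reference for it. As an illustration of the general idea only (not code taken from the repository), a common way to keep an exponentially weighted running average from overflowing is to remember the largest logit seen so far and rescale both accumulators whenever it grows, the same trick behind log-sum-exp. The sketch below does this for a single hidden unit in plain Python; `z` stands in for `u * tanh(g)` and `a` for the logit `q` in the question's code, and both function names are hypothetical:

```python
import math


def rwa_naive(z, a):
    """Direct weighted average sum(z_i * e^{a_i}) / sum(e^{a_i}); overflows for large a."""
    num = sum(zi * math.exp(ai) for zi, ai in zip(z, a))
    den = sum(math.exp(ai) for ai in a)
    return num / den


def rwa_running_max(z, a):
    """Same average, accumulated step by step with a running-max rescaling."""
    n, d, a_max = 0.0, 0.0, -math.inf
    for zi, ai in zip(z, a):
        a_new = max(a_max, ai)
        # Rescale the old accumulators to the new reference point a_new.
        decay = math.exp(a_max - a_new)  # exp(-inf) == 0.0 on the first step
        weight = math.exp(ai - a_new)    # always <= 1, so it cannot overflow
        n = n * decay + zi * weight
        d = d * decay + weight
        a_max = a_new
    return n / d


z = [0.5, -0.2, 0.9]
a = [1.0, 3.0, 2.0]
print(rwa_naive(z, a), rwa_running_max(z, a))
print(rwa_running_max(z, [1000.0, 1001.0, 999.0]))  # rwa_naive would overflow here
```

Both functions compute the same average, but only the running-max version stays finite for large logits. Because every term is scaled by the same factor e^(-max), it cancels exactly between `n` and `d`. Rescaling by `relu(q)` at each step instead multiplies older terms by more factors than newer ones, so older steps end up down-weighted.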

I just posted the corrected code on my repo (https://github.com/jostmey/rwa). I owe a special thanks to some dude named Alex Nichol (a.k.a. unixpickle) for finding the bug. I am re-running all my results. So far the corrected model appears to run at least as well as before 🙂

1 Like

If self.n and self.d are initial states which you want to optimize as additional parameters, then they should be nn.Parameter instances and shouldn’t be assigned to during forward. If they’re just temporary variables to hold state during the forward pass, they should not be attributes of self and should be created and assigned to as ordinary local variables during forward.
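Applying that advice to the class in the question gives the sketch below. `n` and `d` are now local variables of `forward`, and it is written for a recent PyTorch release, so `torch.tanh` and plain tensors replace `F.tanh` and `Variable`. The per-step scaling is copied unchanged from the question. `nn.CrossEntropyLoss`, the learning rate, and the random `line_tensor` / fixed `category_tensor` are placeholders, because the post does not show `criterion` or the data:

```python
import torch
import torch.nn as nn


class RWA(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.x2u = nn.Linear(input_size, hidden_size)
        self.c2g = nn.Linear(input_size + hidden_size, hidden_size)
        self.c2q = nn.Linear(input_size + hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        # input: (seq_len, batch, input_size); hidden: (batch, hidden_size)
        h = torch.tanh(hidden)
        # n and d are plain locals, created fresh on every call, so this graph
        # never reaches back into the one freed by the previous backward().
        n = torch.zeros_like(hidden)
        d = torch.zeros_like(hidden)
        for x in input:
            combined = torch.cat((x, h), 1)
            u = self.x2u(x)
            g = self.c2g(combined)
            q = self.c2q(combined)
            # Same per-step scaling as in the question (see the author's stability note).
            q_greater = torch.relu(q)
            scale = torch.exp(-q_greater)
            a_scale = torch.exp(q - q_greater)
            n = n * scale + u * torch.tanh(g) * a_scale
            d = d * scale + a_scale
            h = torch.tanh(n / d)
        return self.out(h), h

    def init_hidden(self, batch_size=1):
        return torch.randn(batch_size, self.hidden_size)


torch.manual_seed(0)
rwa = RWA(57, 128, 18)  # sizes printed in the question
criterion = nn.CrossEntropyLoss()  # placeholder: the question does not show criterion
learning_rate = 0.005  # placeholder value


def train(category_tensor, line_tensor):
    hidden = rwa.init_hidden()
    rwa.zero_grad()
    output, hidden = rwa(line_tensor, hidden)
    loss = criterion(output, category_tensor)
    loss.backward()
    with torch.no_grad():
        for p in rwa.parameters():
            p -= learning_rate * p.grad  # plain SGD step, as in the question
    return output, loss.item()


# Placeholder data: random "letters" and a fixed category index.
losses = [train(torch.tensor([3]), torch.randn(5, 1, 57))[1] for _ in range(3)]
print(losses)
```

Repeated calls to `train` now run without the "buffers have already been freed" error, because each backward only sees the graph built for its own sequence.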

2 Likes