"No grad accumulator for a saved leaf" error

I’m hitting this error with a feedforward network that uses an embedding layer:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-53-0b3e51f98dc8> in <module>()
     17 
     18         # Getting gradients w.r.t. parameters
---> 19         loss.backward()
     20 
     21         # Updating parameters

~/miniconda3/envs/amn/lib/python3.5/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
     91                 products. Defaults to ``False``.
     92         """
---> 93         torch.autograd.backward(self, gradient, retain_graph, create_graph)
     94 
     95     def register_hook(self, hook):

~/miniconda3/envs/amn/lib/python3.5/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     87     Variable._execution_engine.run_backward(
     88         tensors, grad_tensors, retain_graph, create_graph,
---> 89         allow_unreachable=True)  # allow_unreachable flag
     90 
     91 

RuntimeError: No grad accumulator for a saved leaf!

Any idea?

class FeedforwardNeuralNetModel(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super(FeedforwardNeuralNetModel, self).__init__()
        # Keep the embedding size around for the reshape in forward()
        self.embedding_dim = embedding_dim

        # Embedding layer
        self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.fc1 = nn.Linear(embedding_dim*embedding_dim, hidden_dim)

        self.sigmoid = nn.Sigmoid()

        self.fc2 = nn.Linear(hidden_dim, output_dim)

        self.sigmoid_out = nn.Sigmoid()

    def forward(self, x):
        # Embedding: (batch, seq_len) -> (batch, seq_len, embedding_dim)
        embedded = self.embedding(x)
        embedded = embedded.view(-1, self.embedding_dim*self.embedding_dim)
        out = self.fc1(embedded)

        out = self.sigmoid(out)

        out = self.fc2(out)

        out = self.sigmoid_out(out)

        return out

for epoch in range(num_epochs):
    for i, (samples, labels) in enumerate(train_loader):
        # Load samples
        samples = samples.view(-1, max_len).requires_grad_()
        labels = labels.view(-1, 1)

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get output/logits
        outputs = model(samples)

        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()

Does changing the sample creation to samples = samples.view(-1, max_len).clone().requires_grad_() solve the problem? (Note the extra .clone() before .requires_grad_().)

Thanks for the reply @albanD. That still doesn’t solve it. Really weird, I can’t figure it out. Seems like a bug.

I’m not completely sure, but based on the code snippet it seems samples should contain indices, since it’s passed into an nn.Embedding layer in the model.
If that’s the case, samples should be of dtype=torch.long.
And in that case, samples should not be able to require gradients; at least I’m not aware of a way to get valid gradients for integer values.
But I must say I’m not really an expert in NLP tasks, so maybe you can correct me and we can figure out what’s going on. :wink:
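
A quick standalone check illustrates this (the layer sizes here are arbitrary, not taken from your model). Gradients flow into the embedding’s weight matrix, not into the integer indices, and recent PyTorch versions refuse to set requires_grad on an integer tensor at all; on older versions the problem can surface later, e.g. during backward, which may be what your traceback is showing:

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=100, embedding_dim=8)
indices = torch.randint(0, 100, (4, 5))   # dtype=torch.long, as nn.Embedding expects

out = embedding(indices)
out.sum().backward()
print(embedding.weight.grad.shape)        # gradients land on the weight: torch.Size([100, 8])

# Trying to make the integer indices require gradients fails up front on
# current releases ("only Tensors of floating point ... dtype can require gradients").
try:
    indices.requires_grad_()
except RuntimeError as e:
    print(e)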

Thanks for your comments!

I see what you’re driving at; let me think about it and debug further.


Looking into it, removing .requires_grad_() from samples seems to get rid of that error. I’ll dig deeper to see what’s happening and post my findings here for others.

All right, I can confirm that removing .requires_grad_() fixes it. The samples are just integer indices for the embedding layer, so they shouldn’t require gradients; the model parameters (including embedding.weight) still get their gradients from loss.backward() as usual. So anyone hitting the same issue can simply drop that call.
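
For reference, here is the training loop with that one line changed and everything else kept as before:

for epoch in range(num_epochs):
    for i, (samples, labels) in enumerate(train_loader):
        # Load samples as plain integer index tensors -- no requires_grad_()
        samples = samples.view(-1, max_len)
        labels = labels.view(-1, 1)

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get outputs
        outputs = model(samples)

        # Calculate loss
        loss = criterion(outputs, labels)

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()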
