Runtime error caused by dependency engine?

Test script:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Net(nn.Module):

    def __init__(self, **config):
        super(Net, self).__init__()
        self.config = config
        self.embedding = nn.Embedding(config['vocab_size'], config['embedding_size'])
        self.rnn = nn.LSTM(
            input_size = config['code_size'] + config['embedding_size'],
            hidden_size = config['hidden_size'],
            num_layers = config['num_layers'],
            dropout = config['dropout_ratio'],
        )
        self.linear = nn.Linear(config['hidden_size'], config['vocab_size'])
        self.softmax = nn.Softmax()

    def forward(self, code, step):
        batch_size = code.size()[0]
        prev_index = Variable(torch.LongTensor(batch_size).fill_(self.config['beg_index']))
        prev_h = prev_c = Variable(torch.zeros(self.config['num_layers'], batch_size, self.config['hidden_size']))
        logits = []
        for i in range(step):
            prev_vector = self.embedding(prev_index)
            curr_input = torch.cat((code, prev_vector), 1)
            curr_input = curr_input.view(1, *curr_input.size())
            curr_output, (curr_h, curr_c) = self.rnn(curr_input)
            prev_h, prev_c = curr_h, curr_c
            logit = self.linear(curr_output.squeeze())
            prev_index = torch.max(logit, 1)[1].squeeze()
            logits.append(logit)
        shape = (len(logits),) + logits[0].size()
        logit = torch.cat(logits, 0)
        prob = self.softmax(logit).view(*shape)
        return prob

net = Net(
    code_size = 100,
    hidden_size = 50,
    num_layers = 2,
    dropout_ratio = 0,
    vocab_size = 1000,
    embedding_size = 10,
    beg_index = 1,
)
code = Variable(torch.FloatTensor(14, 100))
prob = net(code, 10)
prob.backward(torch.ones(prob.size()))

Execution result (note the last line):

Traceback (most recent call last):
  File "test.py", line 50, in <module>
    prob.backward(torch.ones(prob.size()))
  File "/Users/warbean/anaconda3/envs/py35/lib/python3.5/site-packages/torch/autograd/variable.py", line 158, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
RuntimeError: could not compute gradients for some functions (Linear, Linear, Linear, Linear, Linear, Linear, Linear, Linear, Linear)

I then found that the error message “could not compute gradients for some functions” comes from https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/engine.cpp#L276, where it is raised by a failed assertion:

THPUtils_assert(not_ready.empty(), "could not compute gradients for some functions (%s)", names.c_str());

Is there something wrong in PyTorch’s dependency engine or in my usage?

(I’m not sure whether PyTorch has a “dependency engine” concept. I borrowed the term from MXNet.)

Yeah, we call that the backward engine. It’s a bug indeed. Use this workaround for now, and I’ll fix it today:

prev_index = Variable(torch.max(logit.data, 1)[1].squeeze())
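In current PyTorch releases, where `Variable` has been merged into `Tensor`, the same intent is usually expressed with `detach()`. The following is a minimal sketch rather than the original fix: the toy sizes and the names `greedy_logits`, `embedding`, and `linear` are assumptions made up for illustration, and the LSTM is left out to keep it short.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for illustration.
vocab_size, embedding_size, batch_size, steps = 20, 8, 4, 5

embedding = nn.Embedding(vocab_size, embedding_size)
linear = nn.Linear(embedding_size, vocab_size)

def greedy_logits(num_steps):
    """Feed the argmax of each step back in as the next input index."""
    prev_index = torch.ones(batch_size, dtype=torch.long)
    logits = []
    for _ in range(num_steps):
        logit = linear(embedding(prev_index))
        # Take the argmax on a detached copy so the index is explicitly
        # cut off from the autograd graph (same intent as using logit.data).
        prev_index = logit.detach().argmax(dim=1)
        logits.append(logit)
    return torch.stack(logits, 0)  # shape: (num_steps, batch_size, vocab_size)

out = greedy_logits(steps)
out.sum().backward()
```

After the backward call, both `linear.weight.grad` and `embedding.weight.grad` are populated, since every step's logits are part of the returned tensor.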

I now see that the problem arises because the Linear layer will get gradients only from the last iteration of the loop (it is followed by non-differentiable argmax in all other cases). Not sure if that’s desired, just wanted to give you a heads up.
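The non-differentiable step mentioned here can be checked directly. This is a small sketch using the current tensor API (not the `Variable` API from the original script), with made-up sizes: the logits stay in the autograd graph, while the integer indices returned by `argmax` do not.

```python
import torch
import torch.nn as nn

linear = nn.Linear(3, 5)
x = torch.randn(2, 3)

logit = linear(x)
index = logit.argmax(dim=1)

# The logits are part of the autograd graph...
logit_tracks_grad = logit.requires_grad and logit.grad_fn is not None
# ...but the integer indices returned by argmax are not.
index_tracks_grad = index.requires_grad or index.grad_fn is not None
```

So no gradient flows back through the index that is fed to the embedding. Any gradient the Linear layer receives arrives through the logits themselves.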

I’ve verified that the bug can only raise errors unnecessarily; it doesn’t affect correctness. There’s a big PR that touches the code I’d need to change to fix it, and since the workaround is quite simple, I’m putting this on hold until that PR is merged. I’ve opened an issue.

Thank you for the workaround! It runs without error.

However, I don’t understand what “get gradients only from the last iteration of the loop” means. What I want is to propagate gradients into each output step, through the whole sequence back to the beginning.

Should I propagate into each output step individually? Like this:

for prob in probs:
    prob.backward(gradient, retain_variables = True)

Sorry, never mind my comment. I only visualized the graph for a single iteration. It’s all fine.
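For reference, calling backward once per output step and calling it once on the stacked outputs accumulate the same parameter gradients. The sketch below assumes a current PyTorch release, where `retain_variables` has been renamed to `retain_graph`. The modules `shared` and `linear` and the helper names are hypothetical stand-ins, not the network from the original script.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
shared = nn.Linear(3, 3)  # stands in for the recurrent part shared by all steps
linear = nn.Linear(3, 4)  # stands in for the output layer
x = torch.randn(2, 3)
params = list(shared.parameters()) + list(linear.parameters())

def forward_steps(num_steps=3):
    h = torch.tanh(shared(x))
    outputs = []
    for _ in range(num_steps):
        h = torch.tanh(shared(h))  # each step builds on the earlier graph
        outputs.append(linear(h))
    return outputs

def zero_grads():
    shared.zero_grad()
    linear.zero_grad()

def copy_grads():
    return [p.grad.clone() for p in params]

# (a) One backward call per step; gradients accumulate in .grad.
zero_grads()
for out in forward_steps():
    out.backward(torch.ones_like(out), retain_graph=True)
per_step = copy_grads()

# (b) A single backward call on the stacked outputs.
zero_grads()
stacked = torch.stack(forward_steps(), 0)
stacked.backward(torch.ones_like(stacked))
single = copy_grads()

same = all(torch.allclose(a, b, atol=1e-5) for a, b in zip(per_step, single))
```

In this sketch `same` comes out true, so the per-step loop is not needed just to reach every output step. A single backward call on the combined output already covers all of them.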