RuntimeError with custom function in LSTMCell

Hello,

I have tried to combine a custom function with an LSTMCell. I copied the time sequence prediction code (https://github.com/pytorch/examples/tree/master/time_sequence_prediction) and added a layer in between the LSTM calls, but I got the error “RuntimeError: could not compute gradients for some functions”.

What I did was apply my custom function to one of the outputs of the LSTM cell; essentially I changed
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
to
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
h_t = self.myFun(h_t)

and I got the error reported above.

This is the code of the custom function (it is a dummy example):

import torch

class multiply(torch.autograd.Function):
    def __init__(self, s):
        super(multiply, self).__init__()
        self.s = s  # constant scaling factor

    def forward(self, input):
        # scale the input by the constant factor
        return input*self.s

    def backward(self, grad_output):
        print('I have been called!!')
        # gradient of input*s with respect to input is s
        return grad_output*self.s

class ML(torch.nn.Module):

    def __init__(self, s):
        super(ML, self).__init__()
        self.m = multiply(s)  # a single Function instance, reused on every forward call

    def forward(self, input):
        output = self.m(input)
        return output

And this is the code for the model:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Sequence(nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 1)
        self.myFun = ML(8)

    def forward(self, input, future = 0):
        outputs = []
        h_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
        c_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
        h_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
        c_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)

        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t = self.myFun(h_t)   ### <------- I added this line
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            outputs += [c_t2]
        for i in range(future):  # if we should predict the future
            h_t, c_t = self.lstm1(c_t2, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            outputs += [c_t2]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs

The forward pass of this model runs correctly, but when I call loss.backward(), it fails with “RuntimeError: could not compute gradients for some functions”.

It is also worth noting that the backward() method of the function I implemented is never called.
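
For context, loss.backward() is called from the training script exactly as in the original example; this is a minimal sketch of that part (assuming an nn.MSELoss criterion and input/target already wrapped in double Variables, as in the example):

seq = Sequence()
seq.double()
criterion = nn.MSELoss()

out = seq(input)               # the forward pass completes without errors
loss = criterion(out, target)
loss.backward()                # RuntimeError: could not compute gradients for some functions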

Hello,

Apparently the problem is that I have to create a new instance of the custom function every time before applying it to an input; otherwise the same Function instance is reused for several nodes of the graph and the backward computation goes wrong. So instead of saving the function as

self.m = multiply(s)

and calling the forward as

return self.m(input)

we should construct it directly on every call:

return multiply(self.s)(input)
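
Putting it together, the wrapper module for this dummy example becomes (a minimal sketch of the fix):

class ML(torch.nn.Module):
    def __init__(self, s):
        super(ML, self).__init__()
        self.s = s  # keep only the scalar, not a Function instance

    def forward(self, input):
        # build a fresh Function instance for every application,
        # so each node of the autograd graph gets its own object
        return multiply(self.s)(input)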

Now it works!