Hello,
I have tried to combine a custom function with an LSTMCell. I copied the time sequence prediction code (https://github.com/pytorch/examples/tree/master/time_sequence_prediction) and added a layer in between the LSTM calls, but I got the error “RuntimeError: could not compute gradients for some functions”.
What I did what apply my custom function to one of the output of the LSTM cell; essentially I changed
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
to
h_t, c_t = self.lstm1(input_t, (h_t, c_t)) h_t = self.myFun(h_t)
and I got the error reported above.
This is the code of the custom function (it is a dummy example):
class multiply(torch.autograd.Function):
def __init__(self, s):
super(multiply, self).__init__()
self.s = s
def forward(self, input):
return input*self.s
def backward(self, grad_output):
print('I have been called!!')
return grad_output*self.s
class ML(torch.nn.Module):
def __init__(self, s):
super(ML, self).__init__()
self.m = multiply(s)
def forward(self, input):
output = self.m(input)
return output
And this is the code for the model:
class Sequence(nn.Module):
def __init__(self):
super(Sequence, self).__init__()
self.lstm1 = nn.LSTMCell(1, 51)
self.lstm2 = nn.LSTMCell(51, 1)
self.myFun = ML(8)
def forward(self, input, future = 0):
outputs = []
h_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
c_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
h_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
c_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
h_t = self.myFun(h_t) ### <------- I added this line
h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
outputs += [c_t2]
for i in range(future):# if we should predict the future
h_t, c_t = self.lstm1(c_t2, (h_t, c_t))
h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
outputs += [c_t2]
outputs = torch.stack(outputs, 1).squeeze(2)
return outputs
We can compute the forward pass of this model correctly, but when I call loss.backward(), it results in “RuntimeError: could not compute gradients for some functions”.
It is also worth noting that the backward() method of the function I implemented is never called.