How can I get higher order gradient on GRUCell?

I am working on a network that needs to regularize the gradient, so I must compute a second derivative.

But I got this error:

RuntimeError: the derivative for _cudnn_rnn_backward is not implemented

I minimized my code to reproduce the error:

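import torch
from torch import nn
from torch.autograd import grad

cell = nn.GRUCell(10, 10).cuda()
parameters = list(cell.parameters())
x = torch.rand(1, 10).cuda()
y = torch.rand(1, 10).cuda()
# Two time steps, starting from a zero hidden state.
incoming = cell(x, torch.zeros(1, 10).cuda())
incoming = cell(y, incoming)
loss = torch.sum(incoming)
# First-order gradients; create_graph=True keeps them differentiable.
grad_all = grad(loss, parameters, retain_graph=True, create_graph=True, only_inputs=True)
print(grad_all[0].requires_grad)
loss2 = torch.sum(torch.cat([v.view(-1) for v in grad_all]))
loss2.backward()  # the second differentiation is where it fails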

and it produces a different error:

RuntimeError: trying to differentiate twice a function that was markedwith @once_differentiable

So is there any workaround to get the second-order gradient? (I’m on PyTorch 0.4.1.)

I’d probably implement a GRU cell from nn.Linear.

Best regards

Thomas
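
For readers landing here, a minimal sketch of what a GRU cell built from nn.Linear might look like. This is not Thomas's code; the class name LinearGRUCell and the attribute names x2h and h2h are made up for illustration, and the gate equations follow the ones documented for nn.GRUCell. Because it only uses linear layers and elementwise ops, autograd should be able to differentiate through it twice.

```python
import torch
from torch import nn


class LinearGRUCell(nn.Module):
    """GRU cell built from nn.Linear (hypothetical name, not part of PyTorch)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One projection each for the input and the hidden state; each
        # produces the reset, update and candidate pre-activations at once.
        self.x2h = nn.Linear(input_size, 3 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x, h):
        x_r, x_z, x_n = self.x2h(x).chunk(3, dim=1)
        h_r, h_z, h_n = self.h2h(h).chunk(3, dim=1)
        r = torch.sigmoid(x_r + h_r)   # reset gate
        z = torch.sigmoid(x_z + h_z)   # update gate
        n = torch.tanh(x_n + r * h_n)  # candidate state
        return (1 - z) * n + z * h


# Same two-step setup as in the question, on CPU for simplicity.
torch.manual_seed(0)
cell = LinearGRUCell(10, 10)
params = list(cell.parameters())
x = torch.rand(1, 10)
y = torch.rand(1, 10)
h = cell(x, torch.zeros(1, 10))
h = cell(y, h)
loss = h.sum()

# First-order gradients, kept in the graph so they can be differentiated again.
grads = torch.autograd.grad(loss, params, create_graph=True)
penalty = torch.cat([g.reshape(-1) for g in grads]).sum()
penalty.backward()  # second-order pass; fills p.grad for every parameter
```

Moving the module and tensors to the GPU with .cuda() should not change anything here, since no fused RNN kernel is involved, but expect it to be slower than the built-in cell.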

Thanks a lot. It works.

So it’s the CUDA version of GRUCell that is marked with once_differentiable. I wish there were a more precise error message, and a simpler way to get the second-order gradient with GRUCell than implementing it myself.

I’m all for it and implementing wonderful things for the various RNNs is on my “things I’d like to do when I get to it list”, but I can’t just promise when that will be…

Best regards

Thomas

@hzhwcmhf, I am stuck with the exact same problem. Would you mind sharing your GRU/LSTM cell implementation with nn.Linear()? Thanks in advance.

I think this may help. https://github.com/emadRad/lstm-gru-pytorch/blob/master/lstm_gru.ipynb

Thanks @hzhwcmhf. That was very helpful.

Just letting everyone know that following Thomas’s suggestion in another thread, I whipped up a JIT-based implementation a while back that works much faster than using nn.Linear.

It’s available here: https://github.com/Maghoumi/JitGRU

Maybe this is better for you:

import torch

cell = torch.nn.GRU(10, 10)
params = list(cell.parameters())
x = torch.rand(1, 10)
# nn.GRU takes input as (seq_len, batch, input_size) and state as (num_layers, batch, hidden_size).
incoming, _ = cell(x.view(1, 1, 10), torch.zeros(1, 1, 10))
loss = torch.sum(incoming)
grad_all = torch.autograd.grad(loss, params, create_graph=True)
loss2 = torch.sum(torch.cat([v.reshape(-1) for v in grad_all]))
grad2 = torch.autograd.grad(loss2, params, create_graph=True)
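
Note that the snippet above runs on the CPU, where the cuDNN kernel is never used. On the GPU, recent PyTorch releases suggest temporarily disabling the cuDNN backend around the RNN forward pass when double backward is needed. A hedged sketch of that, assuming a recent PyTorch version (I have not checked how far back this works, and it may not help on 0.4.1); the variable names are just for illustration and the code falls back to the CPU when no GPU is available:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
rnn = nn.GRU(10, 10).to(device)
params = list(rnn.parameters())
x = torch.rand(1, 1, 10, device=device)    # (seq_len, batch, input_size)
h0 = torch.zeros(1, 1, 10, device=device)  # (num_layers, batch, hidden_size)

# With cuDNN disabled, the forward pass uses PyTorch's native GRU
# implementation, which autograd can differentiate more than once.
with torch.backends.cudnn.flags(enabled=False):
    out, _ = rnn(x, h0)
    loss = out.sum()
    grads = torch.autograd.grad(loss, params, create_graph=True)
    penalty = torch.cat([g.reshape(-1) for g in grads]).sum()
    grads2 = torch.autograd.grad(penalty, params)
```

The trade-off is speed: the native path is generally slower than cuDNN, so it makes sense to disable cuDNN only around the part of the model that needs second-order gradients.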