How can I get a higher-order gradient through GRUCell?

I am working on a network that needs to regularize the gradient, so I must compute a second derivative.
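Regularizing the gradient puts the gradient itself into the loss, so the final backward pass has to differentiate through the first one. Below is a minimal sketch of that pattern on a small feed-forward model. The model, data, and weight `lam` are placeholders, not from the original post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder model and data, for illustration only.
model = nn.Sequential(nn.Linear(10, 10), nn.Tanh(), nn.Linear(10, 1))
x = torch.rand(4, 10)
params = list(model.parameters())

loss = model(x).pow(2).mean()

# create_graph=True records the first backward pass in the graph,
# so the resulting gradients can themselves be differentiated.
grads = torch.autograd.grad(loss, params, create_graph=True)

# Gradient-norm penalty: squared L2 norm of all parameter gradients.
penalty = sum(g.pow(2).sum() for g in grads)

lam = 0.1  # hypothetical regularization weight
total = loss + lam * penalty

# This backward pass needs second derivatives of every op in the model.
total.backward()
```

Every op here (`Linear`, `tanh`, `pow`) has a differentiable backward, which is exactly what fails for the cuDNN and fused GRU kernels discussed in this thread.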

But I got this error:

RuntimeError: the derivative for _cudnn_rnn_backward is not implemented

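I minimized my code to reproduce the error:

import torch
import torch.nn as nn
from torch.autograd import grad

cell = nn.GRUCell(10, 10).cuda()
parameters = list(cell.parameters())
x = torch.rand(1, 10).cuda()
y = torch.rand(1, 10).cuda()
incoming = cell(x, torch.zeros(1, 10).cuda())
incoming = cell(y, incoming)
loss = torch.sum(incoming)
grad_all = grad(loss, parameters, retain_graph=True, create_graph=True, only_inputs=True)
# torch.sum does not accept a list, so concatenate the flattened gradients first
loss2 = torch.cat([v.view(-1) for v in grad_all]).sum()
grad2 = grad(loss2, parameters)  # second derivative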

and ran into another error:

RuntimeError: trying to differentiate twice a function that was marked with @once_differentiable

So is there any workaround to get the second-order gradient? (I’m on PyTorch 0.4.1.)

I’d probably implement a GRU cell from nn.Linear.
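A minimal sketch of that suggestion, assuming the standard GRU gate equations from the PyTorch `nn.GRUCell` documentation. `LinearGRUCell` is a hypothetical name, not a PyTorch class. Because it uses only `nn.Linear`, `sigmoid`, `tanh` and elementwise ops, autograd can differentiate it twice, and it drops into the reproduction code in place of `nn.GRUCell`:

```python
import torch
import torch.nn as nn


class LinearGRUCell(nn.Module):
    """Hypothetical GRU cell built only from nn.Linear and elementwise ops,
    so its backward pass is itself differentiable."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Each Linear yields the stacked (r, z, n) pre-activations.
        self.x2h = nn.Linear(input_size, 3 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x, h):
        gx_r, gx_z, gx_n = self.x2h(x).chunk(3, dim=1)
        gh_r, gh_z, gh_n = self.h2h(h).chunk(3, dim=1)
        r = torch.sigmoid(gx_r + gh_r)   # reset gate
        z = torch.sigmoid(gx_z + gh_z)   # update gate
        n = torch.tanh(gx_n + r * gh_n)  # candidate state
        return (1 - z) * n + z * h       # new hidden state


device = "cuda" if torch.cuda.is_available() else "cpu"
cell = LinearGRUCell(10, 10).to(device)
parameters = list(cell.parameters())
x = torch.rand(1, 10, device=device)
y = torch.rand(1, 10, device=device)

incoming = cell(x, torch.zeros(1, 10, device=device))
incoming = cell(y, incoming)
loss = incoming.sum()

grad_all = torch.autograd.grad(loss, parameters, create_graph=True)
loss2 = torch.cat([g.view(-1) for g in grad_all]).sum()
grad2 = torch.autograd.grad(loss2, parameters)  # second-order gradient
```

The gates are stacked in (r, z, n) order, which matches the documented `weight_ih`/`weight_hh` layout of `nn.GRUCell`, so existing `nn.GRUCell` weights can be copied into `x2h` and `h2h`.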

Best regards


Thanks a lot. It works.

So it’s the CUDA version of GRUCell that is marked with once_differentiable. But I hope there could be a more precise error message, and a simpler way to get the second-order gradient with GRUCell than implementing the cell myself.

I’m all for it, and implementing wonderful things for the various RNNs is on my “things I’d like to do when I get to it” list, but I can’t promise when that will be…

Best regards


@hzhwcmhf, I am stuck with exactly the same problem. Would you mind sharing your GRU/LSTM cell implementation built with nn.Linear? Thanks in advance.

I think this may help.
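The linked code is not reproduced here. For the LSTM half of the question, a comparable hypothetical sketch (`LinearLSTMCell` is an illustrative name) follows the gate equations in the `nn.LSTMCell` documentation, stacking gates in its documented (i, f, g, o) order, and supports a second-order gradient the same way:

```python
import torch
import torch.nn as nn


class LinearLSTMCell(nn.Module):
    """Hypothetical LSTM cell from nn.Linear and elementwise ops,
    so autograd can differentiate through its backward pass."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Each Linear yields the stacked (i, f, g, o) pre-activations.
        self.x2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h, c = state
        i, f, g, o = (self.x2h(x) + self.h2h(h)).chunk(4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c_new = f * c + i * g            # new cell state
        h_new = o * torch.tanh(c_new)    # new hidden state
        return h_new, c_new


cell = LinearLSTMCell(10, 10)
params = list(cell.parameters())
x = torch.rand(1, 10)
state = (torch.zeros(1, 10), torch.zeros(1, 10))

h, c = cell(x, state)
h, c = cell(x, (h, c))
loss = h.sum()

grads = torch.autograd.grad(loss, params, create_graph=True)
loss2 = sum(g.pow(2).sum() for g in grads)
grad2 = torch.autograd.grad(loss2, params)  # second-order gradient
```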


Thanks @hzhwcmhf. That was very helpful.

Just letting everyone know that, following Thomas’s suggestion in another thread, I whipped up a JIT-based implementation a while back that runs much faster than the nn.Linear-based version.

It’s available here:

Maybe this is better for you:

import torch

cell = torch.nn.GRU(10, 10)  # stays on the CPU, so the cuDNN kernel is not used
x = torch.rand(1, 10)
incoming, _ = cell(x.view(1, 1, 10), torch.zeros(1, 1, 10))  # (seq_len, batch, features)
loss = torch.sum(incoming)
params = list(cell.parameters())
grad_all = torch.autograd.grad(loss, params, create_graph=True)
loss2 = torch.cat([v.view(-1) for v in grad_all]).sum()
grad2 = torch.autograd.grad(loss2, params, create_graph=True)