Hi,
I am trying to compute gradients twice with autograd.grad, like this:
import scipy
import scipy.special
import torch
import torch.nn as nn
import torch.autograd
from torch.autograd import Function, Variable
import torch.nn.functional as F
class Policy_Beta(nn.Module):
    def __init__(self):
        super(Policy_Beta, self).__init__()
        self.conv1 = nn.Conv2d(12, 16, 3)
        self.conv2 = nn.Conv2d(16, 5, 3)
        self.affine1 = nn.Linear(5 * 4 * 4, 64)
        self.a = nn.Linear(64, 2)
        self.b = nn.Linear(64, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.affine1(x))
        a = F.relu(self.a(x))
        b = F.relu(self.b(x))
        return a, b
class Polygamma(Function):
    def forward(self, n, input):
        result = torch.Tensor(scipy.special.polygamma(n.numpy(), input.numpy()))
        self.save_for_backward(input, n)
        return result

    def backward(self, grad_output):
        input, n = self.saved_tensors
        # d/dx polygamma(n, x) = polygamma(n + 1, x); n itself gets no gradient
        return torch.Tensor([0]), grad_output * (Polygamma()(Variable(n + 1), Variable(input))).data
policy_net = Policy_Beta()
a, b = policy_net(Variable(torch.Tensor(14, 12, 8, 8)))
z = Variable(torch.Tensor([0]))
k_1 = (Polygamma()(z, a) * Polygamma()(z, b)).mean()
grads = torch.autograd.grad(k_1, policy_net.parameters(), create_graph=True)
k_2 = torch.cat([grad.view(-1) for grad in grads]).mean()
grads = torch.autograd.grad(k_2, policy_net.parameters())
But it’s giving me the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
51 k_2 = torch.cat([grad.view(-1) for grad in grads]).mean()
---> 52 grads = torch.autograd.grad(k_2, policy_net.parameters())
~torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs)
151 return Variable._execution_engine.run_backward(
152 outputs, grad_outputs, retain_graph,
--> 153 inputs, only_inputs)
154
155 status = torch._C._autograd_init()
RuntimeError: Polygamma is not differentiable twice
I am assuming that it’s because I am using scipy.special.polygamma, so the backward pass is computed outside of autograd and can’t be traced a second time. Any idea how I could make this work? (Related thread: Autograd for numpy / scipy functions.)
I thought about not using scipy and coding the polygamma function directly, but that seems rather difficult (considering that it requires derivatives and integrals).
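The closest I have come to a fix is the sketch below, based on the identity that the derivative of polygamma(n, x) is polygamma(n + 1, x): if backward calls the Function again instead of returning raw tensors, the backward graph should itself stay differentiable at every order. This targets the newer @staticmethod Function interface on a recent PyTorch; the name Polygamma2 and the detach() call are my own guesses, and I haven’t verified that the double backward actually goes through:

import scipy.special
import torch
from torch.autograd import Function

class Polygamma2(Function):
    @staticmethod
    def forward(ctx, n, input):
        # n is a plain Python int here; only `input` carries gradients
        ctx.n = n
        ctx.save_for_backward(input)
        result = scipy.special.polygamma(n, input.detach().numpy())
        return torch.from_numpy(result).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # d/dx polygamma(n, x) = polygamma(n + 1, x); expressing backward
        # through the Function itself keeps it differentiable at any order
        return None, grad_output * Polygamma2.apply(ctx.n + 1, input)

# k_1 above would then become:
# k_1 = (Polygamma2.apply(0, a) * Polygamma2.apply(0, b)).mean()

I also noticed that recent PyTorch versions ship a built-in torch.polygamma, which might sidestep the custom Function entirely, though I haven’t checked whether it supports a second backward.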
Any help would be much appreciated!