Make a scipy function twice differentiable using autograd


I am trying to compute gradients twice with autograd.grad like this :

import scipy
import scipy.special
import torch
import torch.nn as nn
import torch.autograd
from torch.autograd import Function, Variable
import torch.nn.functional as F

class Policy_Beta(nn.Module):
    def __init__(self):
        super(Policy_Beta, self).__init__()
        self.conv1 = nn.Conv2d(12, 16, 3)
        self.conv2 = nn.Conv2d(16, 5, 3)
        self.affine1 = nn.Linear(5 * 4 * 4, 64)
        self.a = nn.Linear(64, 2)
        self.b = nn.Linear(64, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.affine1(x))
        a = F.relu(self.a(x))
        b = F.relu(self.b(x))

        return a, b 

class Polygamma(Function):
    def forward(self, n, input):
        result = torch.Tensor(scipy.special.polygamma(n.numpy(), input.numpy()))
        self.save_for_backward(input, n)
        return result

    def backward(self, grad_input):
        input, n = self.saved_tensors
        return torch.Tensor([0]), grad_input * (Polygamma()(Variable(n + 1), Variable(input))).data

policy_net = Policy_Beta()

a, b = policy_net(Variable(torch.Tensor(14, 12, 8, 8)))

z = Variable(torch.Tensor([0]))

k_1 =(Polygamma()(z, a)*Polygamma()(z, b)).mean()
grads = torch.autograd.grad(k_1, policy_net.parameters(), create_graph=True)

k_2 =[grad.view(-1) for grad in grads]).mean()
grads = torch.autograd.grad(k_2, policy_net.parameters())

But it’s giving me the following error :

    RuntimeError                      Traceback (most recent call last)

         51 k_2 =[grad.view(-1) for grad in grads]).mean()
    ---> 52 grads = torch.autograd.grad(k_2, policy_net.parameters())    

    ~torch/autograd/ in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs)
        151     return Variable._execution_engine.run_backward(
        152         outputs, grad_outputs, retain_graph,
    --> 153         inputs, only_inputs)
        155 status = torch._C._autograd_init()

    RuntimeError: Polygamma is not differentiable twice

I am assuming that it’s because I am using scipy.special.polygamma. Any idea how I could make this work ? Autograd for numpy / scipy functions ref

I thought about not using spicy but coding the polygamma function directly but it’s seems rather difficult (considering that it requires derivatives and integrals)

Any help would be much appreciated !