Differentiate twice through a custom pytorch extension using numpy

Hello all,

I’d like to differentiate twice through a PyTorch extension that I implemented manually using numpy and
scipy. My objective is to compute a jvp/rop of this block as done here: Rop.py · GitHub

Is this possible? Should I manually define additional derivatives, and where?

The example I am really interested in is a bit lengthy, below there is a similar one based on the
tutorial example Creating Extensions Using numpy and scipy — PyTorch Tutorials 1.9.0+cu102 documentation

import torch
from torch.autograd import Function
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter, bias):
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_bias = np.sum(grad_output, keepdims=True)
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        return torch.from_numpy(grad_input), torch.from_numpy(grad_filter).to(torch.float), torch.from_numpy(grad_bias).to(torch.float)


class ScipyConv2d(Module):
    def __init__(self, filter_width, filter_height):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(filter_width, filter_height))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter, self.bias)


if __name__ == "__main__":
    module = ScipyConv2d(3, 3)
    input = torch.randn(10, 10)
    output = module(input)

    w = torch.ones_like(output, requires_grad=True)
    tmp = torch.autograd.grad(output, module.parameters(), w, create_graph=True)
    # cannot differentiate through tmp! No grad_fn defined

The mechanic behind double backward is that the operations performed during the initial backward are recorded by autograd.

This means that since you use numpy in the backward of ScipyConv2dFunction, it won’t be able to double backward out of the box.

For this to work you can implement the backward of ScipyConv2dFunction as another custom function ScipyConv2dFunctionBackward.

Let me know if this example makes sense, or are there more details you’d like to know.

import torch

def cube_forward(x):
    return x**3

def cube_backward(grad_out, x):
    return grad_out * 3 * x**2

def cube_backward_backward(grad_out, sav_grad_out, x):
    return grad_out * sav_grad_out * 6 * x

def cube_backward_backward_grad_out(grad_out, x):
    return grad_out * 3 * x**2

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return cube_forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return CubeBackward.apply(grad_out, x)

class CubeBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_out, x):
        ctx.save_for_backward(x, grad_out)
        return cube_backward(grad_out, x)

    @staticmethod
    def backward(ctx, grad_out):
        x, sav_grad_out = ctx.saved_tensors
        dx = cube_backward_backward(grad_out, sav_grad_out, x)
        dgrad_out = cube_backward_backward_grad_out(grad_out, x)
        return dgrad_out, dx

x = torch.tensor(2., requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x). # verify that it works w/ gradgradcheck
2 Likes

Dear @soulitzer,

Thanks a lot for the answer! The operations are clear to me.

Marco

1 Like