Hello all,
I’d like to differentiate twice through a PyTorch extension that I implemented manually using numpy and scipy. My objective is to compute a jvp/rop of this block, as done here: Rop.py · GitHub
Is this possible? Should I manually define additional derivatives, and if so, where?
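For context, the trick in that gist builds a jvp out of two reverse-mode passes. Here is a minimal sketch of the pattern I mean (the helper name jvp and the dummy cotangent w are mine, not from the gist):

import torch

def jvp(f, x, v):
    # g = J^T w is linear in the dummy cotangent w, so differentiating
    # g with respect to w and contracting with v recovers J v.
    x = x.detach().requires_grad_(True)
    y = f(x)
    w = torch.zeros_like(y, requires_grad=True)  # dummy cotangent
    g, = torch.autograd.grad(y, x, w, create_graph=True)  # g = J^T w
    jv, = torch.autograd.grad(g, w, v)                    # jv = J v
    return jv

The second grad call only works if the first backward pass was itself recorded by autograd, which is exactly what fails with the numpy-based extension below.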
The example I am really interested in is a bit lengthy; below is a similar one based on the tutorial example Creating Extensions Using numpy and scipy — PyTorch Tutorials 1.9.0+cu102 documentation:
import torch
from torch.autograd import Function
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter, bias):
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_bias = np.sum(grad_output, keepdims=True)
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        return (torch.from_numpy(grad_input),
                torch.from_numpy(grad_filter).to(torch.float),
                torch.from_numpy(grad_bias).to(torch.float))

class ScipyConv2d(Module):
    def __init__(self, filter_width, filter_height):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(filter_width, filter_height))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter, self.bias)

if __name__ == "__main__":
    module = ScipyConv2d(3, 3)
    input = torch.randn(10, 10)
    output = module(input)
    w = torch.ones_like(output, requires_grad=True)
    tmp = torch.autograd.grad(output, module.parameters(), w, create_graph=True)
    # cannot differentiate through tmp! No grad_fn defined
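My understanding so far: tmp has no grad_fn because backward drops to numpy/scipy, so autograd cannot trace it even with create_graph=True. One direction I am considering is rewriting backward with differentiable torch ops so the second derivative is recorded automatically. A rough sketch of what I mean (the class name Conv2dTwiceDifferentiable is mine, and I am assuming single-channel, unbatched input as in the tutorial):

import torch
import torch.nn.functional as F
from torch.autograd import Function

class Conv2dTwiceDifferentiable(Function):
    # Same math as ScipyConv2dFunction, but backward uses only torch ops,
    # so calling grad with create_graph=True records a graph through it.
    @staticmethod
    def forward(ctx, input, filter, bias):
        ctx.save_for_backward(input, filter)
        # 'valid' cross-correlation (F.conv2d is a cross-correlation)
        result = F.conv2d(input[None, None], filter[None, None])[0, 0]
        return result + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        grad_bias = grad_output.sum(dim=(0, 1), keepdim=True)
        # 'full' convolution of grad_output with the filter
        grad_input = F.conv2d(
            grad_output[None, None],
            filter.flip(0, 1)[None, None],
            padding=(filter.size(0) - 1, filter.size(1) - 1),
        )[0, 0]
        # 'valid' cross-correlation of input with grad_output
        grad_filter = F.conv2d(input[None, None], grad_output[None, None])[0, 0]
        return grad_input, grad_filter, grad_bias

With this version the tmp returned by torch.autograd.grad(..., create_graph=True) should carry grad_fns, so a second grad through it (e.g. with respect to w, as in the jvp sketch above) should no longer fail. Is this the intended way, or can I keep the scipy forward and declare the extra derivatives somewhere?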