I am attempting to integrate this scipy function into PyTorch:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.UnivariateSpline.html
I’ve been looking at this guide:
https://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html
This is the code I have written so far:
from scipy.interpolate import UnivariateSpline
import torch
from torch.autograd import Function
class SPLFunction(Function):
@staticmethod
def forward(ctx, input, s):
np_input = input.detach().numpy()
s = s.detach()
x, y = np_input[0], np_input[1]
spl = UnivariateSpline(x, y,s=s)
result = spl(y)
ctx.save_for_backward(input, s)
return torch.as_tensor(result, dtype=input[0].dtype)
@staticmethod
def backward(ctx, grad_output):
input, s = ctx.saved_tensors
np_input = input.detach().numpy()
s = s.detach()
x, y = np_input[0], np_input[1]
spl = UnivariateSpline(x, y,s=s)
df = spl.derivative(x)
df = torch.from_numpy(df*grad_output)
return(df)
However, this does not seem to be working. I am getting a variety of errors…if anyone has experience integrating scipy functions in to PyTorch, any tips in the right direction would be hugely appreciated!