Hi!
I’m working on a project that has the following pipeline:

1. Numerical integration of a given function
2. Interpolation (using SciPy’s RectBivariateSpline)
3. Inference: given a new tuple (x1, x2), use the interpolation to predict y

Note that the first two parts can be done offline, and the spline function saved for future use (amortizing the cost of numerical integration + spline fitting).
I’d like to integrate the inference part of this pipeline into a PyTorch model, with support for running it on GPU and for backpropagating gradients to (x1, x2).

I’m not sure the torch interpolate function is useful for this task. Would anyone have an idea of how to tackle this?
Thanks

I am not aware of a built-in spline function being offered by pytorch.

If I understand your use case correctly, it is acceptable to fit the spline
function once, in advance of using pytorch, but you want to be able to
evaluate that pre-fit spline and backpropagate through it within the pytorch
framework.

I see two approaches:

Fit the spline with scipy, and then extract the spline coefficients. Implement
the spline evaluation using (differentiable) pytorch tensor functions. You
will then get gpu support and autograd (backpropagation) “for free.”
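As a rough sketch of this first approach (not a tuned implementation), the snippet below pulls the knots and coefficients out of a fitted RectBivariateSpline and evaluates the tensor-product B-spline with plain pytorch ops, using the Cox-de Boor recursion. The names `bspline_basis` and `spline_eval` are made up for illustration, the `sin * cos` grid is just stand-in data, and the inputs are assumed to be 1-D tensors lying inside the fitted domain. It evaluates every basis function densely, which is simple but not the most efficient choice.

```python
import numpy as np
import torch
from scipy.interpolate import RectBivariateSpline


def bspline_basis(x, t, k):
    """All degree-k B-spline basis functions on knots t, evaluated at x.

    x: (N,) tensor, t: (m,) knot tensor. Returns an (N, m - k - 1) tensor.
    (Hypothetical helper, not part of pytorch or scipy.)
    """
    x = x.unsqueeze(-1)
    lo, hi = t[:-1], t[1:]
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1}).
    inside = (x >= lo) & (x < hi)
    # Close the last non-empty interval so that x == t[-1] is still covered.
    at_end = (x == t[-1]) & (lo < hi) & (hi == t[-1])
    B = (inside | at_end).to(x.dtype)
    # Cox-de Boor recursion, raising the degree one step at a time.
    for d in range(1, k + 1):
        den_l = t[d:-1] - t[:-d - 1]
        den_r = t[d + 1:] - t[1:-d]
        # Repeated knots give zero denominators; those terms are defined as 0,
        # so divide by 1 there and mask the result out.
        w_l = (x - t[:-d - 1]) / torch.where(den_l > 0, den_l, torch.ones_like(den_l))
        w_r = (t[d + 1:] - x) / torch.where(den_r > 0, den_r, torch.ones_like(den_r))
        B = w_l * (den_l > 0) * B[:, :-1] + w_r * (den_r > 0) * B[:, 1:]
    return B


def spline_eval(x1, x2, tx, ty, c, kx=3, ky=3):
    """Evaluate sum_ij c[i, j] * B_i(x1) * B_j(x2) pointwise (hypothetical helper)."""
    Bx = bspline_basis(x1, tx, kx)            # (N, nx)
    By = bspline_basis(x2, ty, ky)            # (N, ny)
    C = c.reshape(Bx.shape[1], By.shape[1])   # scipy stores c flat, x index first
    return torch.einsum("ni,ij,nj->n", Bx, C, By)


# Offline step: fit with scipy (stand-in data), then extract knots / coefficients.
xg = np.linspace(0.0, 3.0, 25)
yg = np.linspace(0.0, 2.0, 20)
zg = np.sin(xg)[:, None] * np.cos(yg)[None, :]
spline = RectBivariateSpline(xg, yg, zg, kx=3, ky=3)
knots_x, knots_y = spline.get_knots()
tx = torch.tensor(knots_x, dtype=torch.float64)
ty = torch.tensor(knots_y, dtype=torch.float64)
c = torch.tensor(spline.get_coeffs(), dtype=torch.float64)

# Inference step: differentiable with respect to (x1, x2).
x1 = torch.tensor([0.3, 1.7, 2.9], dtype=torch.float64, requires_grad=True)
x2 = torch.tensor([0.1, 1.2, 1.9], dtype=torch.float64, requires_grad=True)
y = spline_eval(x1, x2, tx, ty, c)
y.sum().backward()  # fills x1.grad and x2.grad
```

To run this on a gpu, moving `tx`, `ty`, `c` and the inputs to the same device should be enough, since only ordinary tensor ops are involved.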

Or:

Package your spline evaluation as a custom autograd Function. Use
scipy to evaluate the spline in your Function's forward() method.

Because you use scipy rather than pytorch for this evaluation, you won’t
get autograd / backpropagation for free. So you also have to implement
your Function's backward() method, used by autograd for its gradient
computation. I would suggest having scipy compute the spline’s derivative
(together with its value) in forward() and save it in ctx for subsequent
use in backward().
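A minimal sketch of this second approach might look like the following. The class name `SciPySplineFn` is made up, the `sin * cos` grid is stand-in data, and the inputs are assumed to be 1-D tensors inside the fitted domain. The partial derivatives come from scipy's `ev(..., dx=1)` and `ev(..., dy=1)` and are stashed in ctx during forward().

```python
import numpy as np
import torch
from scipy.interpolate import RectBivariateSpline


class SciPySplineFn(torch.autograd.Function):
    """Evaluate a pre-fit scipy spline and supply its gradient by hand (illustrative name)."""

    @staticmethod
    def forward(ctx, x1, x2, spline):
        # scipy works on cpu numpy arrays, so detach and copy off the device.
        a = x1.detach().cpu().numpy()
        b = x2.detach().cpu().numpy()
        y = spline.ev(a, b)
        # Compute both partial derivatives now and keep them for backward().
        d1 = spline.ev(a, b, dx=1)
        d2 = spline.ev(a, b, dy=1)
        ctx.save_for_backward(torch.as_tensor(d1).to(x1), torch.as_tensor(d2).to(x2))
        return torch.as_tensor(y).to(x1)

    @staticmethod
    def backward(ctx, grad_out):
        d1, d2 = ctx.saved_tensors
        # Chain rule; the spline object itself gets no gradient.
        return grad_out * d1, grad_out * d2, None


# Offline step: fit the spline once with scipy (stand-in data).
xg = np.linspace(0.0, 3.0, 25)
yg = np.linspace(0.0, 2.0, 20)
zg = np.sin(xg)[:, None] * np.cos(yg)[None, :]
fitted = RectBivariateSpline(xg, yg, zg, kx=3, ky=3)

# Inference step inside pytorch.
u = torch.tensor([0.3, 1.7, 2.9], dtype=torch.float64, requires_grad=True)
v = torch.tensor([0.1, 1.2, 1.9], dtype=torch.float64, requires_grad=True)
out = SciPySplineFn.apply(u, v, fitted)
out.sum().backward()  # fills u.grad and v.grad
```

If the inputs live on a gpu, this version still round-trips through the cpu on every call, which is the cost of letting scipy do the evaluation.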

Note, wrapping scipy’s “forward” / “backward” spline evaluation in a custom
autograd Function won’t get you gpu support for free. However, it’s possible
that an optimized scipy cpu spline implementation could be adequate.

It shouldn’t be very hard to implement the spline evaluation in pytorch, but
I’m lazy, so I would probably lean towards wrapping the scipy implementation
in a Function.