Linear interpolation in PyTorch

Hi all.
I am new to PyTorch. Is there a linear interpolation function similar to np.interp?
Please let me know if you need any details. Thanks in advance.

Hi,

I’m afraid there is none.
If you’re fine with a non-differentiable, CPU-only implementation, you can use the NumPy one as a workaround: res = torch.from_numpy(np.interp(x.numpy(), fx.numpy(), fp.numpy()))
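A minimal sketch of that workaround, wrapped in a hypothetical helper `interp_numpy`. The `.detach().cpu()` calls are an addition to the one-liner: `.numpy()` refuses tensors that require grad or live on a GPU. No gradient flows through the result.

```python
import numpy as np
import torch


def interp_numpy(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: np.interp(x, xp, fp) on tensors (non-differentiable, CPU)."""
    # .numpy() only works on CPU tensors that do not require grad,
    # so detach and move everything to the CPU first.
    res = np.interp(x.detach().cpu().numpy(),
                    xp.detach().cpu().numpy(),
                    fp.detach().cpu().numpy())
    # np.interp returns float64; cast back to fp's dtype on x's device.
    return torch.from_numpy(res).to(device=x.device, dtype=fp.dtype)


xp = torch.tensor([0.0, 1.0, 3.0])
fp = torch.tensor([0.0, 10.0, 30.0])
# Values 5, 20, 30: the query 5.0 lies past xp[-1], so it is clamped to fp[-1].
print(interp_numpy(torch.tensor([0.5, 2.0, 5.0]), xp, fp))
```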


Thank you for your reply. I was looking for a differentiable one, since it is used to compute the prediction.

No, we don’t have one.
If you have a favorite algorithm for it, you should be able to write a differentiable version using PyTorch’s primitives.
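One possible way to do this with PyTorch primitives, offered as a sketch rather than an official API: find each query’s interval with `torch.searchsorted`, then blend the two neighbouring values. The function name `interp` is hypothetical. It assumes `xp` is 1-D and strictly increasing with at least two points, mimics `np.interp` by clamping queries outside `[xp[0], xp[-1]]`, and is differentiable with respect to both `x` and `fp`.

```python
import torch


def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
    """Hypothetical differentiable analogue of np.interp for 1-D xp/fp.

    x may have any shape; xp must be 1-D, strictly increasing, same dtype as x.
    """
    # Index of the right end of each query's interval, kept in [1, len(xp) - 1]
    # so that idx - 1 and idx are always valid neighbours.
    idx = torch.searchsorted(xp, x.detach().contiguous())
    idx = idx.clamp(1, xp.numel() - 1)
    x0, x1 = xp[idx - 1], xp[idx]
    y0, y1 = fp[idx - 1], fp[idx]
    # Fractional position inside the interval, clamped to [0, 1] so that
    # out-of-range queries return fp[0] or fp[-1], as np.interp does.
    t = ((x - x0) / (x1 - x0)).clamp(0.0, 1.0)
    return y0 + t * (y1 - y0)


xp = torch.tensor([0.0, 1.0, 3.0])
fp = torch.tensor([0.0, 10.0, 30.0], requires_grad=True)
x = torch.tensor([-1.0, 0.5, 2.0, 5.0], requires_grad=True)
out = interp(x, xp, fp)  # same values as np.interp: 0, 5, 20, 30
out.sum().backward()     # gradients reach both x and fp
```

The `searchsorted` lookup itself has no gradient (hence the `detach`); derivatives come from the arithmetic on `t`, `y0` and `y1`. Gradients with respect to `x` are zero for clamped, out-of-range queries.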


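Try torch.nn.functional.interpolate() with mode='linear'. It is differentiable, but it resamples a uniformly spaced (batch, channels, length) tensor to a new size rather than evaluating it at arbitrary points the way np.interp does.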
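A small sketch of that suggestion with illustrative toy values. `F.interpolate` with `mode="linear"` expects a 3-D `(batch, channels, length)` tensor, so it only covers the case where the samples and the queries are both evenly spaced.

```python
import torch
import torch.nn.functional as F

# Four samples at positions 0, 1, 2, 3, shaped (batch=1, channels=1, length=4).
y = torch.tensor([[[0.0, 1.0, 2.0, 3.0]]], requires_grad=True)

# Resample to 7 points. With align_corners=True the endpoints are kept,
# so output point i sits at input position i * (4 - 1) / (7 - 1) = 0.5 * i.
out = F.interpolate(y, size=7, mode="linear", align_corners=True)
print(out)  # values 0.0, 0.5, 1.0, ..., 3.0

# Gradients flow back to the input samples.
out.sum().backward()
print(y.grad)
```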


Actually, torch.nn.functional.grid_sample() might be closer to what you’re looking for, since it samples at arbitrary (normalized) coordinates.
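A minimal 1-D sketch of that idea. It assumes the samples sit at evenly spaced positions `0, 1, ..., W-1`, treats the signal as a one-pixel-high image, and maps query positions into the `[-1, 1]` range that `grid_sample` expects. The helper name `interp_grid_sample` is hypothetical. Gradients flow to both the samples and the query positions.

```python
import torch
import torch.nn.functional as F


def interp_grid_sample(y: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: interpolate y (samples at 0..W-1) at positions pos.

    y has shape (W,); pos has shape (P,) with values in [0, W-1].
    """
    W = y.numel()
    image = y.view(1, 1, 1, W)                 # (N, C, H_in=1, W_in=W)
    # With align_corners=True, -1 maps to position 0 and +1 to position W-1.
    gx = 2.0 * pos / (W - 1) - 1.0
    gy = torch.zeros_like(gx)                  # the single row sits at y = 0
    grid = torch.stack([gx, gy], dim=-1).view(1, 1, -1, 2)  # (N, H_out=1, W_out=P, 2)
    out = F.grid_sample(image, grid, mode="bilinear", align_corners=True)
    return out.view(-1)                        # (P,)


y = torch.tensor([0.0, 10.0, 20.0, 30.0], requires_grad=True)
pos = torch.tensor([0.5, 1.25, 3.0], requires_grad=True)
vals = interp_grid_sample(y, pos)  # approximately 5.0, 12.5, 30.0
vals.sum().backward()              # fills y.grad and pos.grad
```

For an irregular grid you would first have to map each query to a fractional index, which is the part `searchsorted` handles in a hand-written version.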


I just wrote a package that implements this (GPU-compatible & differentiable). https://github.com/sbarratt/torch_interpolations


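Here’s an implementation by @aliutkus.
Like np.interp, it can be evaluated at arbitrary, non-regularly spaced query points newx. Note, however, that the version below assumes the sample positions are x = 0, 1, 2, ..., len(y)-1.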

```python
import torch


def interp1d(y: torch.Tensor, newx: torch.Tensor) -> torch.Tensor:
    '''
    Function for simple linear interpolation in PyTorch.
    Assumes x is [0, 1, 2, 3, ..., len(y)-1]
    newx can be any dimension and the dimension of result will match newx's.

    Test:
        y = torch.arange(25, dtype=torch.float)
        interp1d(y, torch.tensor([1.3, 2.5]))
        interp1d(y, torch.arange(24).reshape(2, 3, 4) + 0.23)
        interp1d(y, torch.arange(24).reshape(2, 3, 4))
    '''
    assert len(y.shape) == 1  # 1-dimensional
    assert len(y) > 1
    assert newx.min() >= 0
    assert newx.max() <= len(y) - 1

    npoints = len(y)
    ndim = len(newx.shape)

    # Compare every query point with every left bin edge 0, 1, ..., npoints-2.
    # Memory grows as newx.numel() * npoints.
    x = newx.unsqueeze(ndim)
    x0 = torch.arange(npoints - 1, device=newx.device)
    diff = (x - x0).to(torch.float)
    in_bin = (diff >= 0) & (diff < 1)  # one True per query, except when x == npoints-1

    y0 = y.to(torch.float)
    y0left = y0[:-1].expand(diff.shape)
    y0right = y0[1:].expand(diff.shape)
    y0max = y0[-1].expand(diff.shape) / (npoints - 1)  # Averaging because it gets summed (npoints - 1) times later.

    weight = diff * in_bin
    interp = torch.lerp(y0left, y0right, weight) * in_bin
    interp = interp + (x == (npoints - 1)) * y0max  # The corner case: newx equals the last position

    result = interp.sum(dim=ndim)

    return result
```