Linear interpolation in PyTorch

Hi all.
I am new to PyTorch. Is there a linear interpolation function similar to np.interp?
Please let me know if you need any more details. Thanks in advance.

Hi,

I’m afraid there is none.
If you’re fine with a non-differentiable, CPU-only implementation, you can use the NumPy one as a workaround: res = torch.from_numpy(np.interp(x.numpy(), fx.numpy(), fp.numpy()))

1 Like

Thank you for your reply. I was looking for a differentiable one, as it is used to predict the output.

No, we don’t have one.
If you have a favorite algorithm for it, I’m sure you can implement a differentiable version using PyTorch’s primitives.

1 Like

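Try torch.nn.functional.interpolate() with mode='linear'. It is differentiable, but note that it resamples a uniformly spaced signal to a new size rather than evaluating arbitrary query points the way np.interp does.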

1 Like

Actually, torch.nn.functional.grid_sample() might be closer to what you’re looking for.

2 Likes

I just wrote a package that implements this (GPU-compatible & differentiable). https://github.com/sbarratt/torch_interpolations

3 Likes

Here’s an implementation by @aliutkus.
Like np.interp, it supports non-regular grids.

I wrote the following function. Your comments would be much appreciated:

def interp1d(y: torch.Tensor, newx: torch.Tensor) -> torch.Tensor:
    '''
    Linearly interpolate ``y`` sampled at x = [0, 1, 2, ..., len(y) - 1].

    Differentiable with respect to both ``y`` and ``newx``. Memory use is
    proportional to ``newx.numel()``; no ``[..., len(y)]`` intermediate is built.

    Args:
        y: 1-D tensor with at least two samples. Integer tensors are converted
            to the default floating-point dtype.
        newx: query positions of any shape, all within [0, len(y) - 1].

    Returns:
        Tensor with the same shape as ``newx``, on ``y``'s device, in ``y``'s
        floating dtype (or the default dtype when ``y`` is integral).

    Raises:
        ValueError: if ``y`` is not 1-D with at least two samples, or if any
            query lies outside [0, len(y) - 1] (NaN queries included).
    '''
    if y.dim() != 1:
        raise ValueError("y must be 1-dimensional")
    npoints = y.shape[0]
    if npoints < 2:
        raise ValueError("y must contain at least two points")

    # torch.lerp needs floating-point start/end/weight of a single dtype, so
    # y and the query positions are brought to one float dtype and device.
    dtype = y.dtype if y.is_floating_point() else torch.get_default_dtype()
    yf = y.to(dtype)
    xf = newx.to(device=y.device, dtype=dtype)

    # Written as "not (in range)" so NaN queries are rejected too.
    if xf.numel() > 0 and not (xf.min() >= 0 and xf.max() <= npoints - 1):
        raise ValueError("newx must lie within [0, len(y) - 1]")

    # Left sample of each query's segment. Clamping to npoints - 2 puts the
    # right end point x == len(y) - 1 into the last segment with weight 1.
    left = xf.detach().floor().clamp(0, npoints - 2).long()
    # Fractional position inside the segment; gradients flow through xf.
    weight = xf - left.to(dtype)
    return torch.lerp(yf[left], yf[left + 1], weight)

Hi, I faced the same issue and implemented a basic solution for the 1D case. Maybe it helps or can serve as a baseline for higher order interpolation :slight_smile:

import matplotlib.pyplot as plt
import torch
import numpy as np


# Random sample abscissae: 60 uniform draws on [-25, 25), clamped to [-21, 19]
# and passed through unique(), which sorts them and drops duplicates.
x = torch.unique(torch.clamp(torch.rand(60) * 50 - 25, min=-21, max=19))
# One random ordinate per abscissa.
y = torch.rand_like(x)
# Integer support points -25, -23, ..., 25 (the end bound 27 is exclusive).
x_new = torch.arange(-25, 27, 2)


def interpolate_to_support(x: torch.Tensor, y: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    """
    This is a rudimentary implementation of numpy.interp for the 1D case only, with zero fill
    outside [x.min(), x.max()] (like ``np.interp(support, x, y, left=0, right=0)``).
    If the x values are not unique, this behaves differently to np.interp.
    :param x: The original coordinates; 1-D and sorted ascending (required by torch.searchsorted).
    :param y: The original values, same length as x.
    :param support: The support points to which y shall be interpolated.
    :return: A float32 tensor with the same shape as support; 0 where a support point lies
             outside the range of x.
    """
    # slope[i] is the slope of the segment that ENDS at x[i], i.e. between x[i-1] and x[i].
    # slope[0] has no segment; it is only read when a support point equals x[0], where the
    # offset below is zero, so its value does not matter.
    slope = torch.zeros_like(x)
    slope[1:] = (y[1:] - y[:-1]) / (x[1:] - x[:-1])

    # Support points inside the sampled range; everything else stays 0.
    inside_mask = (support >= x.min()) & (support <= x.max())
    support_inside = support[inside_mask]
    # searchsorted (side="left") gives the index of the first x >= point, i.e. the RIGHT end
    # of the segment that contains it.
    right_indices = torch.searchsorted(x, support_inside)
    # Offset from that right end point (always <= 0).
    offset = support_inside - x[right_indices]
    # Walk back from the right end along the segment's slope.
    inside_values = y[right_indices] + slope[right_indices] * offset

    support_values = torch.zeros_like(support).float()
    support_values[inside_mask] = inside_values
    return support_values


# Compare the torch implementation with np.interp (zero fill outside the
# range of x) and with the original samples.
custom_values = interpolate_to_support(x, y, x_new)
reference_values = np.interp(x_new, x, y, left=0, right=0)

plt.plot(x_new, custom_values, "go", label="Custom interpolation")
plt.plot(x_new, reference_values, "y-", label="np.interp")
plt.plot(x, y, "b--", label="original values")
plt.legend()
plt.show()

Cheers :wave:

Hi, I’ve created a basic solution for arbitrary dimensions (interpolation is done along the last dimension) :smile: This implementation only uses simple operations like max, min, abs, gather, etc. I tried to avoid torch.searchsorted because it isn’t supported for ONNX export. A big disadvantage is that my implementation can be quite inefficient and memory-hungry for large tensors with many points (either x or xp, e.g. a few thousand each) or when the batch size is too large. Otherwise, it should be really fast.

def interp(x: torch.Tensor, y: torch.Tensor, xp: torch.Tensor):
    '''
    Piecewise-linear interpolation of (x, y) at xp along the last dimension.

    Uses only elementwise ops, min/max reductions and gather. It avoids
    torch.searchsorted, which the author reports is not supported for ONNX
    export. A [..., P, N + 2] distance matrix is built, so memory grows with
    P * N. Queries outside the range of x take the y value of the nearest
    end point.

    x : [..., N]   sample positions (need not be sorted)
    y : [..., N]   sample values
    xp: [..., P]   query positions
    returns [..., P]
    '''
    lo_x, lo_pos = torch.min(x, dim=-1, keepdim=True)
    hi_x, hi_pos = torch.max(x, dim=-1, keepdim=True)
    lo_y = torch.gather(y, -1, lo_pos)
    hi_y = torch.gather(y, -1, hi_pos)

    # Sentinel knots bracketing every query; they repeat the end values so
    # out-of-range queries are held constant.
    lo_pad = torch.minimum(lo_x, torch.amin(xp, -1, keepdim=True))
    hi_pad = torch.maximum(hi_x, torch.amax(xp, -1, keepdim=True))
    knots_x = torch.cat([lo_pad, x, hi_pad], dim=-1)
    knots_y = torch.cat([lo_y, y, hi_y], dim=-1)

    offset = knots_x.unsqueeze(-2) - xp.unsqueeze(-1)  # [..., P, N + 2]
    at_or_left = (offset <= 0).float()  # 1 where knot <= query
    strictly_right = 1 - at_or_left
    gap = offset.abs()
    # Ten times the largest gap: exceeds any real candidate distance, so it
    # effectively excludes knots on the wrong side from the min below.
    penalty = gap.amax((-1, -2), keepdim=True) * 10

    left_dist, left_idx = (gap * at_or_left + penalty * strictly_right).min(-1, keepdim=False)  # [..., P]
    right_dist, right_idx = (gap * strictly_right + penalty * at_or_left).min(-1, keepdim=False)  # [..., P]

    left_y = torch.gather(knots_y, -1, left_idx)
    right_y = torch.gather(knots_y, -1, right_idx)
    return left_y + left_dist / (left_dist + right_dist) * (right_y - left_y)

Note: The memory inefficiency can be solved by sorting x, at the cost of the extra overhead of calling torch.sort. For anyone who wants to optimize using torch.searchsorted, this is the equivalent version:

def robust_interp(x: torch.Tensor, y: torch.Tensor, xp: torch.Tensor):
    '''
    Piecewise-linear interpolation of (x, y) at xp along the last dimension,
    using torch.sort + torch.searchsorted (no [..., P, N] intermediate).

    x : [..., N]   sample positions (need not be sorted)
    y : [..., N]   sample values
    xp: [..., P]   query positions
    returns [..., P]

    Queries outside [min(x), max(x)] take the y value at the nearest end
    point, like np.interp with its default left/right.
    '''
    x_min, min_indices = torch.min(x, dim=-1, keepdim=True)
    x_max, max_indices = torch.max(x, dim=-1, keepdim=True)

    y_min = torch.gather(y, -1, min_indices)
    y_max = torch.gather(y, -1, max_indices)

    xp_min = torch.amin(xp, -1, keepdim=True)
    xp_max = torch.amax(xp, -1, keepdim=True)

    # Handle out-of-bound queries: sentinel knots at min(x_min, xp_min) and
    # max(x_max, xp_max) repeat the end values, so every query has a knot <= it
    # and out-of-range queries are held constant.
    x = torch.cat([torch.minimum(x_min, xp_min), x, torch.maximum(x_max, xp_max)], dim=-1)
    y = torch.cat([y_min, y, y_max], dim=-1)

    x_sorted, sorted_idx = torch.sort(x, dim=-1)
    y_sorted = torch.gather(y, -1, sorted_idx)

    last = x_sorted.shape[-1] - 1
    # Index of the last knot <= xp. side="right" is required: with side="left"
    # a query equal to the smallest knot gives left == right and 0/0 = NaN.
    left_idx = (torch.searchsorted(x_sorted, xp, right=True) - 1).clamp(0, last)
    # The following knot is strictly greater than xp, except at the top end,
    # where the clamp makes left == right (zero-width span, handled below).
    right_idx = (left_idx + 1).clamp(max=last)

    left_x = torch.gather(x_sorted, -1, left_idx)
    right_x = torch.gather(x_sorted, -1, right_idx)
    left_y = torch.gather(y_sorted, -1, left_idx)
    right_y = torch.gather(y_sorted, -1, right_idx)

    span = right_x - left_x
    has_span = span > 0
    # Guarded division: the inner where keeps the discarded branch finite so
    # its gradient is 0 rather than NaN.
    ratio = (xp - left_x) / torch.where(has_span, span, torch.ones_like(span))
    weight = torch.where(has_span, ratio, torch.zeros_like(ratio))

    yp = left_y + weight * (right_y - left_y)
    return yp

To avoid NaN gradients, an epsilon is added:

def torch_interpolate(x, x_points, y_points, eps=1e-7):
    """
    Differentiable version of the 'np.interp' function in torch.

    Notes:
        - x_points and y_points should be 1-D vectors of corresponding points.
        - x_points should be sorted in ascending order.
        - x can be of any shape; the result has the same shape.
        - Queries outside [x_points[0], x_points[-1]] take the nearest end value.

    Args:
        x: query positions.
        x_points: sorted sample positions, shape [N].
        y_points: sample values, shape [N].
        eps: unused; kept so existing calls that pass it keep working.
            Zero-width spans (queries clamped to an end point) are handled with
            a guarded division instead of adding eps to the denominator, which
            would bias every in-range weight and dominate fine grids.

    Returns:
        Interpolated values with the same shape as x.
    """
    last = x_points.shape[0] - 1
    # First sample >= x; the segment is [left_idx, right_idx].
    right_idx = torch.searchsorted(x_points, x)
    left_idx = (right_idx - 1).clamp(min=0)
    right_idx = right_idx.clamp(max=last)

    left_x = x_points[left_idx]
    right_x = x_points[right_idx]
    left_y = y_points[left_idx]
    right_y = y_points[right_idx]

    # span is zero exactly when x is at or beyond an end point (left == right).
    span = right_x - left_x
    has_span = span > 0
    # Double torch.where: the inner one keeps the discarded division finite so
    # its gradient is 0 instead of NaN.
    ratio = (x - left_x) / torch.where(has_span, span, torch.ones_like(span))
    weight = torch.where(has_span, ratio, torch.zeros_like(ratio))

    interpolated_y = left_y + weight * (right_y - left_y)
    return interpolated_y