Interpolation for size doubling while keeping previous data points

Hi, I want to create a differentiable interpolation function that takes a computed NxCxHxW tensor and can upsample the H dimension, the W dimension, or both, to either DIM*2 or DIM*2 - 1, while keeping the previous data points.
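
To make the requirement concrete, here is a tiny 1D sketch of what I mean by keeping the data points (the values are made up purely for illustration):

import torch

x = torch.tensor([1.0, 3.0, 5.0])
# upsampled to 2*3 - 1 = 5 points: the original values stay put, midpoints go in between
expected = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
# in the 2*3 = 6 point case the last value would simply be duplicated:
# [1.0, 2.0, 3.0, 4.0, 5.0, 5.0]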

I tried to use the built-in bilinear mode of the interpolate function, but it does not appear to have a way of keeping the input data as part of the output, which is necessary for me.
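
For context, this is roughly what I tried with the built-in (the exact arguments may have differed, this is just the gist):

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 16, 16)
up = F.interpolate(x, size=(31, 31), mode='bilinear', align_corners=False)
# most of the original 16x16 values do not appear unchanged anywhere in `up`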

Is there a better way to compute this kind of upscaling?
I have implemented it by rolling the tensor (see the function below), but doing it in both dimensions at once slows the function down significantly compared to the built-in interpolate, despite technically being less computation.
(This is important, as I am optimising this task; a rough sketch of how I call and time it is at the end of the post.)

PS: Is this even the correct tag for this question on the forums, or should I repost it somewhere else?


import torch


def interpolate_keep_values(input_tensor, desired_shape):
    """
    :param input_tensor:
    :param desired_shape: last two dimensions of the output tensor, only accepts `in_dim * 2` and `(in_dim * 2) - 1` due to
    requirement of keeping the same calculated values. In case of `in_dim * 2` shape, the last value is duplicated due to missing information.
    :return: interpolated tensor of the desired shape
    Ideally, the `(in_dim * 2)` outputs should have the edge duplicated, but that costs significant extra time.
    """
    #
    orig_x, orig_y = input_tensor.shape[-2:]
    out_x, out_y = desired_shape
    first = (out_x - orig_x) > 1   # the first (H) dimension needs upsampling
    second = (out_y - orig_y) > 1  # the second (W) dimension needs upsampling
    dblx = not (orig_x * 2) > out_x  # True -> out_x == orig_x * 2, otherwise out_x == orig_x * 2 - 1
    dbly = not (orig_y * 2) > out_y  # True -> out_y == orig_y * 2, otherwise out_y == orig_y * 2 - 1
    # Yes, I *could* shorten the code by not duplicating it in every branch, but this way the decision
    # process is more clearly visible, and the parameter switching would have been ugly.
    if first:  # first (H) dimension needs upsampling
        off = torch.roll(input_tensor, 1, dims=-2)
        mids = (input_tensor + off) / 2  # midpoints between consecutive rows (index 0 wraps around)
        if second:  # second (W) dimension needs upsampling
            if dblx:
                if dbly:
                    # interleave the row midpoints with the original rows, then shift so an original row comes first
                    out_tensor = torch.roll(
                        torch.stack((mids, input_tensor), dim=-2)
                        .view(input_tensor.shape[0], input_tensor.shape[1], out_x, orig_y), -1, dims=-2)
                    # repeat the same trick along the last (W) dimension
                    off2 = torch.roll(out_tensor, 1, dims=-1)
                    mids2 = (out_tensor + off2) / 2
                    out_tensor = torch.roll(
                        torch.stack((mids2, out_tensor), dim=-1).view(
                            input_tensor.shape[0], input_tensor.shape[1], out_x, out_y), -1, dims=-1)

                    # edge duplication for the *2 case (see docstring), disabled because it costs extra time:
                    # out_tensor[:, :, -1, :] = out_tensor[:, :, -2, :]
                    # out_tensor[:, :, :, -1] = out_tensor[:, :, :, -2]
                    return out_tensor
                else:
                    out_tensor = torch.roll(
                        torch.stack((mids, input_tensor), dim=-2)
                        .view(input_tensor.shape[0], input_tensor.shape[1], out_x, orig_y), -1, dims=-2)
                    off2 = torch.roll(out_tensor, 1, dims=-1)
                    mids2 = (out_tensor + off2) / 2
                    out_tensor = torch.roll(
                        torch.stack((mids2, out_tensor), dim=-1).view(
                            input_tensor.shape[0], input_tensor.shape[1], out_x, orig_y * 2), -1, dims=-1)
                    # out_tensor[:, :, -1, :] = out_tensor[:, :, -2, :]
                    # drop the wrap-around value in the last (W) dimension for the 2*n - 1 case
                    return out_tensor[:, :, :, :out_y]
            else:
                if dbly:
                    out_tensor = torch.roll(
                        torch.stack((mids, input_tensor), dim=-2)
                        .view(input_tensor.shape[0], input_tensor.shape[1], orig_x * 2, orig_y), -1, dims=-2)
                    off2 = torch.roll(out_tensor, 1, dims=-1)
                    mids2 = (out_tensor + off2) / 2
                    out_tensor = torch.roll(
                        torch.stack((mids2, out_tensor), dim=-1).view(
                            input_tensor.shape[0], input_tensor.shape[1], orig_x * 2, out_y), -1, dims=-1)
                    # out_tensor[:, :, :, -1] = out_tensor[:, :, :, -2]
                    # drop the wrap-around row for the 2*n - 1 case in the first (H) dimension
                    return out_tensor[:, :, :out_x, :]
                else:
                    out_tensor = torch.roll(
                        torch.stack((mids, input_tensor), dim=-2)
                        .view(input_tensor.shape[0], input_tensor.shape[1], 2 * orig_x, orig_y), -1, dims=-2)
                    off2 = torch.roll(out_tensor, 1, dims=-1)
                    mids2 = (out_tensor + off2) / 2
                    # drop the wrap-around row and column for the 2*n - 1 case in both dimensions
                    return torch.roll(
                        torch.stack((mids2, out_tensor), dim=-1).view(
                            input_tensor.shape[0], input_tensor.shape[1], 2 * orig_x, 2 * orig_y), -1, dims=-1)[:, :, :out_x, :out_y]

        else:
            if dblx:
                out_tensor = torch.roll(
                    torch.stack((mids, input_tensor), dim=-2)
                    .view(input_tensor.shape[0], input_tensor.shape[1], out_x, orig_y), -1, dims=-2)
                return out_tensor
            else:
                out_tensor = torch.roll(
                    torch.stack((mids, input_tensor), dim=-2)
                    .view(input_tensor.shape[0], input_tensor.shape[1], orig_x * 2, orig_y), -1, dims=-2)
                # out_tensor[:, :, -1, :] = out_tensor[:, :, -2, :]
                # drop the wrap-around row for the 2*n - 1 case
                return out_tensor[:, :, :out_x, :]

    elif second:
        off2 = torch.roll(input_tensor, 1, dims=-1)
        mids2 = ((input_tensor + off2) / 2)
        if dbly:
            out_tensor = torch.roll(
                torch.stack((mids2, input_tensor), dim=-1)
                .view(input_tensor.shape[0], input_tensor.shape[1], orig_x, out_y), -1, dims=-1)
            out_tensor[:, :, :, -1] = out_tensor[:, :, :, -2]
            return out_tensor
        else:
            out_tensor = torch.roll(
                torch.stack((mids2, input_tensor), dim=-1)
                .view(input_tensor.shape[0], input_tensor.shape[1], orig_x, orig_y * 2), -1, dims=-1)
            return out_tensor[:, :, :, :out_y]

    return input_tensor  # neither dimension needs upsampling, nothing to do
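
For reference, this is roughly how I call it and compare it against the built-in interpolate (the shapes and the iteration count are just example values, not my real workload):

import time

import torch
import torch.nn.functional as F

x = torch.rand(8, 64, 32, 32)

start = time.perf_counter()
for _ in range(100):
    out = interpolate_keep_values(x, (63, 63))
print("custom:", time.perf_counter() - start)

start = time.perf_counter()
for _ in range(100):
    ref = F.interpolate(x, size=(63, 63), mode='bilinear', align_corners=False)
print("built-in:", time.perf_counter() - start)

# sanity check: every other position of the custom output is an untouched input value
assert torch.allclose(out[..., ::2, ::2], x)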