Hi, i want to create a (derivable) interpolation function, that takes calculated NxCxHxW tensor, and can upsample either H, W or both dimensions to be DIM*2 or DIM*2 - 1, while keeping the previous data points.
I tried to use the in-built bilinear mode of interpolate function, but it does not appear to have a way of keeping the input data as part of output, which is necessary for me.
Is there a better way to calculate upscaling in such a way?
I have implemented this with rolling the tensor, see below, however to do it in both dimensions simultaneously slows down the function significantly, compared to in-built interpolate, despite technically being less calculation.
(this is important, as i am optimising this task)
PS: is this even the correct tag for this issue on the forums, or should i repost it somewhere else?
def interpolate_keep_values(input_tensor, desired_shape):
"""
:param input_tensor:
:param desired_shape: last two dimensions of the output tensor, only accepts `in_dim * 2` and `(in_dim * 2) - 1` due to
requirement of keeping the same calculated values. In case of `in_dim * 2` shape, the last value is duplicated due to missing information.
:return: interpolated tensor of the desired shape
Ideally, the `(in_dim * 2)` outputs should have the edge duplicated, but that costs significant extra time.
"""
#
orig_x, orig_y = input_tensor.shape[-2:]
out_x, out_y = desired_shape
first = (out_x - orig_x) > 1
second = (out_y - orig_y) > 1
dblx = not (orig_x * 2) > out_x
dbly = not (orig_y * 2) > out_y
#Yes, i *could* prettify the code by not duplicating it in every if, but the decision process is more clearly visible
#and the parameter switching would have been ugly
if first:# first dim
off = torch.roll(input_tensor, 1, dims=-2)
mids = ((input_tensor + off)/2)
if second:#second dim
if dblx:
if dbly:
#out_tensor = torch.zeros((input_tensor.shape[0], input_tensor.shape[1], out_x, out_y))
out_tensor = torch.roll(torch.stack((mids, input_tensor), dim=-2)
.view(input_tensor.shape[0], input_tensor.shape[1], out_x,
orig_x), -1, dims=-2)
off2 = torch.roll(out_tensor, 1, dims=-1)
mids2 = ((out_tensor + off2) / 2)
out_tensor = torch.roll(
torch.stack((mids2, out_tensor), dim=-1).view(
input_tensor.shape[0], input_tensor.shape[1], out_x, out_y), -1, dims=-1)
#out_tensor[:, :, -1, :] = out_tensor[:, :, -2, :]
#out_tensor[:, :, :, -1] = out_tensor[:, :, :, -2]
return out_tensor
else:
out_tensor = torch.roll(torch.stack((mids, input_tensor), dim=-2)
.view(input_tensor.shape[0], input_tensor.shape[1], out_x,
orig_x), -1, dims=-2)
off2 = torch.roll(out_tensor, 1, dims=-1)
mids2 = ((out_tensor + off2) / 2)
out_tensor = torch.roll(
torch.stack((mids2, out_tensor), dim=-1).view(
input_tensor.shape[0], input_tensor.shape[1], out_x, orig_y*2), -1, dims=-1)
#out_tensor[:, :, -1, :] = out_tensor[:, :, -2, :]
return out_tensor[:, :, :, :out_y]
else:
if dbly:
out_tensor = torch.roll(torch.stack((mids, input_tensor), dim=-2)
.view(input_tensor.shape[0], input_tensor.shape[1], orig_x*2,
orig_x), -1, dims=-2)
off2 = torch.roll(out_tensor, 1, dims=-1)
mids2 = ((out_tensor + off2) / 2)
out_tensor = torch.roll(
torch.stack((mids2, out_tensor), dim=-1).view(
input_tensor.shape[0], input_tensor.shape[1], orig_x*2, out_y), -1, dims=-1)
#out_tensor[:, :, :, -1] = out_tensor[:, :, :, -2]
return out_tensor[:, :, :out_x, :]
else:
out_tensor = torch.roll(torch.stack((mids, input_tensor), dim=-2)
.view(input_tensor.shape[0], input_tensor.shape[1], 2*orig_x, orig_x), -1, dims=-2)
off2 = torch.roll(out_tensor, 1, dims=-1)
mids2 = ((out_tensor + off2) / 2)
return torch.roll(
torch.stack((mids2, out_tensor), dim=-1).view(
input_tensor.shape[0], input_tensor.shape[1], 2*orig_x, 2*orig_y), -1, dims=-1)[:, :, :out_x, :out_y]
else:
if dblx:
out_tensor = torch.roll(
torch.stack((mids, input_tensor), dim=-2)
.view(input_tensor.shape[0], input_tensor.shape[1], out_x, orig_y), -1, dims=-2)
return out_tensor
else:
out_tensor = torch.roll(torch.stack((mids, input_tensor), dim=-2)
.view(input_tensor.shape[0], input_tensor.shape[1], orig_x*2, orig_x), -1, dims=-2)
#out_tensor[:, :, -1, :] = out_tensor[:, :, -2, :]
return out_tensor[:, :, :out_x, :]
elif second:
off2 = torch.roll(input_tensor, 1, dims=-1)
mids2 = ((input_tensor + off2) / 2)
if dbly:
out_tensor = torch.roll(
torch.stack((mids2, input_tensor), dim=-1)
.view(input_tensor.shape[0], input_tensor.shape[1], orig_x, out_y), -1, dims=-1)
out_tensor[:, :, :, -1] = out_tensor[:, :, :, -2]
return out_tensor
else:
out_tensor = torch.roll(
torch.stack((mids2, input_tensor), dim=-1)
.view(input_tensor.shape[0], input_tensor.shape[1], orig_y, orig_y*2), -1, dims=-1)
return out_tensor[:, :, :, :out_y]
return input_tensor