Hi! I find that resampling with F.interpolate and F.grid_sample leads to inconsistent results.
First, I try to do the downsampling:
# try downsampling
import torch
import torch.nn.functional as F
w_src, w_tgt = 8, 4
input = torch.arange(w_src * w_src).view(1, 1, w_src, w_src).float()
# Create grid to downsample input
d = torch.linspace(-1, 1, w_tgt)
meshx, meshy = torch.meshgrid((d, d))
grid = torch.stack((meshy, meshx), 2)
grid = grid.unsqueeze(0) # add batch dim
output_gs_true = torch.nn.functional.grid_sample(input, grid, mode='bilinear', align_corners=True)
# option-1: rescale the align_corners=True grid to pixel-center (align_corners=False) coordinates
grid_false = ((grid + 1) * (w_tgt - 1) + 1) / w_tgt - 1
# option-2: calculate in the correct way from ground-up
# d = torch.arange(w_tgt)
# meshx, meshy = torch.meshgrid((d, d))
# _grid = torch.stack((meshy, meshx), 2)[None]
# grid_false = ((_grid + 1/2) / w_tgt) * 2 - 1
output_gs_false = torch.nn.functional.grid_sample(input, grid, mode='bilinear', align_corners=False)
output_gs_false_c = torch.nn.functional.grid_sample(input, grid_false, mode='bilinear', align_corners=False)
output_itp_true = F.interpolate(input, size=w_tgt, mode='bilinear', align_corners=True)
output_itp_false = F.interpolate(input, size=w_tgt, mode='bilinear', align_corners=False)
output_gs_false_c and output_itp_false are consistent.
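For reference, this is roughly how I check "consistent". It is only a minimal sketch: `pixel_center_grid` is a hypothetical helper name (it is just option-2 above wrapped in a function), `inp` stands in for `input`, the `atol=1e-5` tolerance is my own choice, and the `indexing='ij'` argument assumes a reasonably recent PyTorch.

```python
import torch
import torch.nn.functional as F

def pixel_center_grid(w_tgt):
    # Hypothetical helper: normalized coordinates of the output pixel centers
    # under the align_corners=False convention (same formula as option-2).
    d = (torch.arange(w_tgt).float() + 0.5) / w_tgt * 2 - 1
    rows, cols = torch.meshgrid(d, d, indexing='ij')
    # grid_sample expects (x, y) in the last dim, and x indexes columns.
    return torch.stack((cols, rows), dim=2).unsqueeze(0)

w_src, w_tgt = 8, 4
inp = torch.arange(w_src * w_src).view(1, 1, w_src, w_src).float()

out_gs = F.grid_sample(inp, pixel_center_grid(w_tgt), mode='bilinear', align_corners=False)
out_itp = F.interpolate(inp, size=w_tgt, mode='bilinear', align_corners=False)

# Largest element-wise gap and a tolerance-based equality check.
max_diff = (out_gs - out_itp).abs().max().item()
same = torch.allclose(out_gs, out_itp, atol=1e-5)
print(max_diff, same)
```

In the downsampling case every sample point lies between the first and last source pixel centers, so no padding is involved.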
Then, I try to do the upsampling:
# try upsampling
w_src, w_tgt = 4, 8
input = torch.arange(w_src * w_src).view(1, 1, w_src, w_src).float()
# Create grid to upsample input
d = torch.linspace(-1, 1, w_tgt)
meshx, meshy = torch.meshgrid((d, d))
grid = torch.stack((meshy, meshx), 2)
grid = grid.unsqueeze(0) # add batch dim
output_gs_true = torch.nn.functional.grid_sample(input, grid, mode='bilinear', align_corners=True)
# option-1: correct the wrong grid
grid_false = ((grid + 1) * (w_tgt - 1) + 1) / w_tgt - 1
# option-2: calculate in the correct way from ground-up
# d = torch.arange(w_tgt)
# meshx, meshy = torch.meshgrid((d, d))
# _grid = torch.stack((meshy, meshx), 2)[None]
# grid_false = ((_grid + 1/2) / w_tgt) * 2 - 1
output_gs_false = torch.nn.functional.grid_sample(input, grid, mode='bilinear', align_corners=False)
output_gs_false_c = torch.nn.functional.grid_sample(input, grid_false, mode='bilinear', align_corners=False)
output_itp_true = F.interpolate(input, size=w_tgt, mode='bilinear', align_corners=True)
output_itp_false = F.interpolate(input, size=w_tgt, mode='bilinear', align_corners=False)
output_gs_false_c and output_itp_false are different.
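To narrow this down, here is a minimal diagnostic sketch that checks where the two outputs disagree. It tests one hypothesis only: that the mismatch comes from sample points falling outside the source pixel centers (for 4 to 8 the outermost points map to source index -0.25 and 3.25), where `grid_sample` uses its default `padding_mode='zeros'`. The helper name `pixel_center_grid`, the variable names, and the tolerance are my own assumptions, not anything from the PyTorch docs.

```python
import torch
import torch.nn.functional as F

def pixel_center_grid(w_tgt):
    # Hypothetical helper: output pixel centers in normalized coordinates
    # for align_corners=False (same formula as option-2 above).
    d = (torch.arange(w_tgt).float() + 0.5) / w_tgt * 2 - 1
    rows, cols = torch.meshgrid(d, d, indexing='ij')
    return torch.stack((cols, rows), dim=2).unsqueeze(0)  # (x, y) order

w_src, w_tgt = 4, 8
inp = torch.arange(w_src * w_src).view(1, 1, w_src, w_src).float()
grid = pixel_center_grid(w_tgt)

out_itp = F.interpolate(inp, size=w_tgt, mode='bilinear', align_corners=False)
out_zeros = F.grid_sample(inp, grid, mode='bilinear',
                          padding_mode='zeros', align_corners=False)
out_border = F.grid_sample(inp, grid, mode='bilinear',
                           padding_mode='border', align_corners=False)

# Compare the interior (outermost ring of output pixels dropped) and the full maps.
interior_matches = torch.allclose(out_zeros[..., 1:-1, 1:-1],
                                  out_itp[..., 1:-1, 1:-1], atol=1e-5)
full_matches_zeros = torch.allclose(out_zeros, out_itp, atol=1e-5)
full_matches_border = torch.allclose(out_border, out_itp, atol=1e-5)
print(interior_matches, full_matches_zeros, full_matches_border)
```

If only the outer ring differs and `padding_mode='border'` closes the gap, that would suggest the grid itself is fine and the difference lies in how the two functions treat out-of-range samples.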
Is there some operation that causes the inconsistency between F.interpolate() and F.grid_sample(), or am I calculating the sampling grid incorrectly? @bnehoran