Inconsistency between upsampling with F.interpolate() and F.grid_sample()

Hi! I found that upsampling with F.interpolate() and F.grid_sample() gives inconsistent results.

First, I tried downsampling:

import torch
import torch.nn.functional as F

# try downsampling
w_src, w_tgt = 8, 4
input = torch.arange(w_src * w_src).view(1, 1, w_src, w_src).float()

# Create a normalized grid of w_tgt x w_tgt points spanning [-1, 1] to downsample input
d = torch.linspace(-1, 1, w_tgt)
meshx, meshy = torch.meshgrid(d, d, indexing='ij')
grid = torch.stack((meshy, meshx), 2)  # last dim is (x, y), the order grid_sample expects
grid = grid.unsqueeze(0)  # add batch dim

output_gs_true = F.grid_sample(input, grid, mode='bilinear', align_corners=True)
# option-1: correct the wrong grid (move the corner-aligned points to target pixel centers)
grid_false = ((grid + 1) * (w_tgt - 1) + 1) / w_tgt - 1
# option-2: calculate in the correct way from ground-up
# d = torch.arange(w_tgt)
# meshx, meshy = torch.meshgrid(d, d, indexing='ij')
# _grid = torch.stack((meshy, meshx), 2)[None]
# grid_false = ((_grid + 1/2) / w_tgt) * 2 - 1

output_gs_false = F.grid_sample(input, grid, mode='bilinear', align_corners=False)
output_gs_false_c = F.grid_sample(input, grid_false, mode='bilinear', align_corners=False)

output_itp_true = F.interpolate(input, size=w_tgt, mode='bilinear', align_corners=True)
output_itp_false = F.interpolate(input, size=w_tgt, mode='bilinear', align_corners=False)

For downsampling, output_gs_false_c and output_itp_false match.
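A minimal self-contained check of the downsampling result, assuming PyTorch 1.10 or later (for `indexing='ij'`). It builds the pixel-center grid the same way as option-2 and compares the two outputs with `torch.allclose`; the helper name `center_grid` is hypothetical, not part of PyTorch.

```python
import torch
import torch.nn.functional as F

def center_grid(w_tgt):
    # Hypothetical helper: normalized (x, y) coordinates of the pixel centers of a
    # w_tgt x w_tgt output, as grid_sample expects them with align_corners=False
    d = (torch.arange(w_tgt, dtype=torch.float32) + 0.5) / w_tgt * 2 - 1
    meshx, meshy = torch.meshgrid(d, d, indexing='ij')
    return torch.stack((meshy, meshx), 2).unsqueeze(0)  # shape (1, H, W, 2)

w_src, w_tgt = 8, 4
x = torch.arange(w_src * w_src, dtype=torch.float32).view(1, 1, w_src, w_src)

gs = F.grid_sample(x, center_grid(w_tgt), mode='bilinear', align_corners=False)
itp = F.interpolate(x, size=w_tgt, mode='bilinear', align_corners=False)
print(torch.allclose(gs, itp))  # True for this 8 -> 4 downsampling
```

Every sample point here falls strictly inside the input, so the choice of padding mode never comes into play.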

Then I tried upsampling:

import torch
import torch.nn.functional as F

# try upsampling
w_src, w_tgt = 4, 8
input = torch.arange(w_src * w_src).view(1, 1, w_src, w_src).float()

# Create a normalized grid of w_tgt x w_tgt points spanning [-1, 1] to upsample input
d = torch.linspace(-1, 1, w_tgt)
meshx, meshy = torch.meshgrid(d, d, indexing='ij')
grid = torch.stack((meshy, meshx), 2)  # last dim is (x, y), the order grid_sample expects
grid = grid.unsqueeze(0)  # add batch dim

output_gs_true = F.grid_sample(input, grid, mode='bilinear', align_corners=True)

# option-1: correct the wrong grid (move the corner-aligned points to target pixel centers)
grid_false = ((grid + 1) * (w_tgt - 1) + 1) / w_tgt - 1
# option-2: calculate in the correct way from ground-up
# d = torch.arange(w_tgt)
# meshx, meshy = torch.meshgrid(d, d, indexing='ij')
# _grid = torch.stack((meshy, meshx), 2)[None]
# grid_false = ((_grid + 1/2) / w_tgt) * 2 - 1

output_gs_false = F.grid_sample(input, grid, mode='bilinear', align_corners=False)
output_gs_false_c = F.grid_sample(input, grid_false, mode='bilinear', align_corners=False)

output_itp_true = F.interpolate(input, size=w_tgt, mode='bilinear', align_corners=True)
output_itp_false = F.interpolate(input, size=w_tgt, mode='bilinear', align_corners=False)

For upsampling, however, output_gs_false_c and output_itp_false differ.
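A minimal sketch that locates the disagreement, assuming PyTorch 1.10 or later. It prints the source-pixel coordinate each output column samples, (i + 0.5) * w_src / w_tgt - 0.5, and then the absolute difference between the two outputs. The interior pixels agree; the differences are confined to the outermost rows and columns, whose sample points fall outside the input.

```python
import torch
import torch.nn.functional as F

w_src, w_tgt = 4, 8
x = torch.arange(w_src * w_src, dtype=torch.float32).view(1, 1, w_src, w_src)

# Pixel-center grid for align_corners=False (same values as option-2 in the post)
d = (torch.arange(w_tgt, dtype=torch.float32) + 0.5) / w_tgt * 2 - 1
meshx, meshy = torch.meshgrid(d, d, indexing='ij')
grid = torch.stack((meshy, meshx), 2).unsqueeze(0)

# Source-pixel coordinate sampled by each output column (rows behave the same way)
src = (torch.arange(w_tgt, dtype=torch.float32) + 0.5) * w_src / w_tgt - 0.5
print(src)  # the first and last entries lie outside [0, w_src - 1]

gs = F.grid_sample(x, grid, mode='bilinear', align_corners=False)  # padding_mode='zeros' by default
itp = F.interpolate(x, size=w_tgt, mode='bilinear', align_corners=False)

diff = (gs - itp).abs()[0, 0]
print(diff)
# Interior pixels (rows/cols 1 .. w_tgt - 2) agree; only the outer ring differs
print(torch.allclose(diff[1:-1, 1:-1], torch.zeros(w_tgt - 2, w_tgt - 2)))  # True
```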

Is there some operation that makes F.interpolate() and F.grid_sample() inconsistent, or am I computing the sampling grid wrong? @bnehoran
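The coordinate mapping can also be worked out by hand. The derivation below uses my own notation ($W_\text{in}$ and $W_\text{out}$ for the input and output widths, $i$ for an output index, $g_i$ for the normalized grid value, $x_i$ for the source-pixel coordinate) and the `align_corners=False` conventions as I understand them from the PyTorch documentation. Option-1 reduces to the pixel-center grid of option-2, and `grid_sample` unnormalizes it to the same source coordinate that `F.interpolate` computes:

```latex
g_i = \frac{\left(-1 + \frac{2i}{W_\text{out}-1} + 1\right)(W_\text{out}-1) + 1}{W_\text{out}} - 1
    = \frac{2i+1}{W_\text{out}} - 1,
\qquad
x_i = \frac{(g_i + 1)\,W_\text{in} - 1}{2}
    = \left(i + \tfrac{1}{2}\right)\frac{W_\text{in}}{W_\text{out}} - \tfrac{1}{2}
```

For $W_\text{in} = 4$, $W_\text{out} = 8$ this gives $x_0 = -\tfrac{1}{4}$ and $x_7 = \tfrac{13}{4}$, both outside the valid range $[0, W_\text{in} - 1] = [0, 3]$. `F.interpolate` effectively clamps such coordinates to the edge pixel, whereas `grid_sample` with the default `padding_mode='zeros'` mixes in zero-valued neighbors. For 8 to 4 every $x_i = 2i + \tfrac{1}{2}$ lies inside $[0, 7]$, so the two functions agree there.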

Oh, I think I get it: I should use padding_mode='border' in F.grid_sample() to be consistent with F.interpolate(). In that case, in which situations do interpolate() and grid_sample() still differ, as noted in this issue?
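A sketch confirming the `padding_mode='border'` observation for the two sizes in this thread, assuming PyTorch 1.10 or later; `center_grid` is a hypothetical helper. With `'border'`, `grid_sample` clamps out-of-range sample coordinates to the edge pixels, which matches what `F.interpolate` does here. This only checks these two square cases; it does not address any other differences mentioned in the issue.

```python
import torch
import torch.nn.functional as F

def center_grid(w_tgt):
    # Hypothetical helper: pixel-center grid for align_corners=False, shape (1, H, W, 2)
    d = (torch.arange(w_tgt, dtype=torch.float32) + 0.5) / w_tgt * 2 - 1
    meshx, meshy = torch.meshgrid(d, d, indexing='ij')
    return torch.stack((meshy, meshx), 2).unsqueeze(0)

results = {}
for w_src, w_tgt in [(8, 4), (4, 8)]:  # downsampling, then upsampling
    x = torch.arange(w_src * w_src, dtype=torch.float32).view(1, 1, w_src, w_src)
    itp = F.interpolate(x, size=w_tgt, mode='bilinear', align_corners=False)
    zeros = F.grid_sample(x, center_grid(w_tgt), mode='bilinear',
                          padding_mode='zeros', align_corners=False)
    border = F.grid_sample(x, center_grid(w_tgt), mode='bilinear',
                           padding_mode='border', align_corners=False)
    # (matches with zeros padding, matches with border padding)
    results[(w_src, w_tgt)] = (torch.allclose(zeros, itp), torch.allclose(border, itp))
    print(w_src, '->', w_tgt, results[(w_src, w_tgt)])
# 8 -> 4 prints (True, True); 4 -> 8 prints (False, True)
```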