Hello, I’m trying to explicitly write down an equivalent to torch.nn.functional.grid_sample()
the reason is that I need to build a TensorRT engine that executes it but the onnx node GridSample
is not yet implemented in trt.
To start, I’ve tried re-implementing the c algorithm in python:
def inside(x, y, H, W):
return 0 <= x < W and 0 <= y < H
def grid_sample(img, coords, mode="bilinear", padding_mode="zeros", align_corners=False):
N, C, IH, IW = img.shape
H, W = coords.shape[1:3]
output = torch.zeros([N, C, H, W])
for n in range(N):
for h in range(H):
for w in range(W):
ix = coords[n, h, w, 0]
iy = coords[n, h, w, 1]
ix = ((ix + 1) / 2) * (IW - 1)
iy = ((iy + 1) / 2) * (IH - 1)
ix_nw = floor(ix)
iy_nw = floor(iy)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
for c in range(C):
nw_val = img[n, c, iy_nw, ix_nw] if inside(ix_nw, iy_nw, IH, IW) else 0
ne_val = img[n, c, iy_ne, ix_ne] if inside(ix_ne, iy_ne, IH, IW) else 0
sw_val = img[n, c, iy_sw, ix_sw] if inside(ix_sw, iy_sw, IH, IW) else 0
se_val = img[n, c, iy_se, ix_se] if inside(ix_se, iy_se, IH, IW) else 0
out_val = nw_val * nw + ne_val * ne + sw_val * sw + se_val * se
output[n, c, h, w] = out_val
return output
This is the result but even executing this I get different value than the one returned by F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
Why could that be?