Differentiable indexing: grid specifies where each input pixel is mapped

I’m looking to perform the following operation in a differentiable manner:
In: dim N x C x H_in x W_in
grid: dim N x H_in x W_in x 2 (integer (row, col) positions in Out)

Out: dim N x C x H_out x W_out

(N is batch_size, C is number of channels)
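Written as a sum, the loop below computes the following. This is a restatement of the code, not a new definition:

```latex
\text{Out}[n, c, y, x] = \sum_{i=0}^{H_{in}-1} \sum_{j=0}^{W_{in}-1}
  \mathbb{1}\big[\,\text{grid}[n,i,j,0] = y \;\wedge\; \text{grid}[n,i,j,1] = x\,\big]\;
  \text{In}[n, c, i, j]
```

Because the indicator is either 0 or 1, the operation is linear in In, so gradients with respect to In are well defined. The indicator is piecewise constant in grid, so with integer grid values there is no useful gradient with respect to grid itself.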


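import torch

def scatter_map_loop(In, grid, H_out, W_out):
    N, C, H_in, W_in = In.shape
    Out = torch.zeros((N, C, H_out, W_out), dtype=In.dtype)
    for i in range(H_in):
        for j in range(W_in):
            for n in range(N):
                # add input feature (n, :, i, j) at output position grid[n, i, j]
                Out[n, :, int(grid[n, i, j, 0]), int(grid[n, i, j, 1])] += In[n, :, i, j]
    return Out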
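The loop can be vectorized by flattening each (row, col) target into a single index row * W_out + col and using Tensor.scatter_add, which sums colliding entries just like the += in the loop. A minimal sketch, assuming grid holds integer positions that are all within Out's bounds; scatter_map is a hypothetical name. It is differentiable with respect to In (scatter_add is linear in its source) but, as with the loop, gives no gradient with respect to grid.

```python
import torch

def scatter_map(inp, grid, H_out, W_out):
    """Vectorized form of: Out[n, :, grid[n,i,j,0], grid[n,i,j,1]] += inp[n, :, i, j].

    inp:  N x C x H_in x W_in
    grid: N x H_in x W_in x 2, integer (row, col) positions in the output
          (assumed to be in bounds; no range check is done here).
    """
    N, C, H_in, W_in = inp.shape
    rows = grid[..., 0].long().reshape(N, -1)      # N x (H_in * W_in)
    cols = grid[..., 1].long().reshape(N, -1)
    flat_idx = rows * W_out + cols                 # linear index into H_out * W_out
    src = inp.reshape(N, C, H_in * W_in)
    # scatter_add needs an index with the same shape as src
    index = flat_idx.unsqueeze(1).expand(N, C, H_in * W_in)
    out = inp.new_zeros(N, C, H_out * W_out)
    out = out.scatter_add(2, index, src)           # duplicates are summed
    return out.reshape(N, C, H_out, W_out)
```

Since each input value is added exactly once, the gradient of out.sum() with respect to inp is all ones, which is a quick way to confirm that gradients reach the input.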

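In other words, grid[n, i, j, :] states where in Out the input feature In[n, :, i, j] is sent; if several input pixels map to the same output position, their features are summed.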
I know that in F.grid_sample, grid[n, i, j, :] specifies the opposite: the location in the input that output pixel (i, j) is sampled from.
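If gradients with respect to grid are also needed, one commonly used option (swapped in here, not part of the question) is bilinear splatting, a forward-warping counterpart of grid_sample's bilinear sampling: each input pixel is split across the four output pixels around a fractional target, with bilinear weights. A minimal sketch, assuming grid holds float (row, col) coordinates in output pixel units, unlike grid_sample's normalized [-1, 1] (x, y) convention. bilinear_splat is a hypothetical helper, not a PyTorch API, and contributions that land outside Out are dropped.

```python
import torch

def bilinear_splat(inp, grid, H_out, W_out):
    """Soft scatter: each input pixel is split across the 4 output pixels
    surrounding its fractional target (row, col).

    inp:  N x C x H_in x W_in
    grid: N x H_in x W_in x 2, float (row, col) in output pixel units (assumption)
    Differentiable w.r.t. inp and grid (through the bilinear weights).
    """
    N, C, H_in, W_in = inp.shape
    src = inp.reshape(N, C, -1)                        # N x C x P
    r = grid[..., 0].reshape(N, -1)                    # N x P
    c = grid[..., 1].reshape(N, -1)
    r0, c0 = r.floor(), c.floor()
    out = inp.new_zeros(N, C, H_out * W_out)
    for dr in (0, 1):
        for dc in (0, 1):
            rr, cc = r0 + dr, c0 + dc                  # one of the 4 neighbours
            # bilinear weight: (1 - |row distance|) * (1 - |col distance|)
            w = (1 - (r - rr).abs()) * (1 - (c - cc).abs())
            # drop neighbours that fall outside the output
            valid = (rr >= 0) & (rr < H_out) & (cc >= 0) & (cc < W_out)
            w = w * valid.to(w.dtype)
            idx = (rr.clamp(0, H_out - 1) * W_out + cc.clamp(0, W_out - 1)).long()
            index = idx.unsqueeze(1).expand(N, C, -1)
            out = out.scatter_add(2, index, src * w.unsqueeze(1))
    return out.reshape(N, C, H_out, W_out)
```

With an integer-valued grid, all weight goes to the target pixel itself, so the result matches the hard scatter in the loop; with fractional targets, small changes in grid change the weights, which is what makes the gradient with respect to grid non-zero.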