How does grid_sample(x, grid) work?

Oh, okay, so here is the deal:
PyTorch currently has three different underlying implementations of grid_sample() (a vectorized CPU 2D version, a non-vectorized CPU 3D version, and a CUDA implementation for both 2D and 3D), but their behavior is supposed to be essentially the same.

If you just want the basic idea, I think the CUDA version is the easiest of the three to understand. You can find the 2D case here.

The important lines for zero padding are these. They perform the zero padding implicitly by calling within_bounds_2d(), so each term of the bilinear interpolation is added only if its pixel is in bounds. Since the output starts at the 0 from line 198, a grid point whose four neighboring pixels are all out of bounds gets 0, and a point just past the edge gets only its in-bounds terms (i.e., it is blended with zeros).
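To make the zero-padding logic concrete, here is a minimal pure-Python sketch of sampling one point of a single-channel image in the same spirit as the CUDA kernel: unnormalize the grid coordinate, find the four neighboring pixels, start the output at 0, and add each weighted term only if within_bounds_2d() says the pixel exists. This is not PyTorch's code. `bilinear_sample_zeros` is a hypothetical name, my `within_bounds_2d` only mimics the kernel helper of that name, the image is a list of rows, and I'm assuming the align_corners=False unnormalization.

```python
import math


def within_bounds_2d(h, w, H, W):
    # True only if the integer pixel (h, w) lies inside an H x W image
    return 0 <= h < H and 0 <= w < W


def bilinear_sample_zeros(img, gx, gy):
    """Sample a single-channel image (list of rows) at normalized (gx, gy).

    Zero padding: neighbors outside the image simply contribute nothing.
    Assumes the align_corners=False mapping from [-1, 1] to pixels.
    """
    H, W = len(img), len(img[0])
    # map [-1, 1] to pixel coordinates (align_corners=False convention)
    ix = ((gx + 1) * W - 1) / 2
    iy = ((gy + 1) * H - 1) / 2
    # integer corners of the cell containing (ix, iy)
    x0, y0 = math.floor(ix), math.floor(iy)
    x1, y1 = x0 + 1, y0 + 1
    # bilinear weights along each axis
    wx1, wy1 = ix - x0, iy - y0
    wx0, wy0 = 1 - wx1, 1 - wy1
    out = 0.0  # accumulator starts at zero, as in the CUDA kernel
    for y, x, weight in ((y0, x0, wy0 * wx0), (y0, x1, wy0 * wx1),
                         (y1, x0, wy1 * wx0), (y1, x1, wy1 * wx1)):
        # only in-bounds pixels are added; out-of-bounds ones act as zeros
        if within_bounds_2d(y, x, H, W):
            out += img[y][x] * weight
    return out
```

For the 2x2 image [[1, 2], [3, 4]], the grid point (-1, -0.5) lands halfway between the first pixel center and a nonexistent pixel to its left, so only the in-bounds neighbor contributes and the result is 0.5 rather than 1.0. A point far outside the image returns 0.0.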

Note that this out-of-bounds check does not affect the border and reflection padding modes, since in those cases the grid points have already been brought in bounds by the clip and reflect operations here.
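For comparison, the border and reflection modes transform the coordinate before any sampling happens, so every neighbor ends up in bounds and the zero check never fires. The sketch below is modeled on the clip/reflect idea in PyTorch's grid sampler, but it is my own simplified Python version: `padded_coordinate` is a hypothetical name, and it works on an already-unnormalized pixel coordinate along one axis.

```python
import math


def clip_coordinates(x, size):
    # border padding: clamp a pixel coordinate into [0, size - 1]
    return min(size - 1.0, max(0.0, x))


def reflect_coordinates(x, low, high):
    # mirror x back into [low, high], bouncing off the bounds as often as needed
    span = high - low
    if span <= 0:
        return 0.0
    dist = abs(x - low)
    extra = math.fmod(dist, span)
    flips = math.floor(dist / span)
    # an even number of bounces keeps the direction, an odd number reverses it
    return low + extra if flips % 2 == 0 else high - extra


def padded_coordinate(x, size, mode, align_corners=False):
    """Bring a pixel coordinate in bounds for 'border' or 'reflection' padding."""
    if mode == "border":
        return clip_coordinates(x, size)
    if mode == "reflection":
        # reflect about pixel centers (align_corners=True) or pixel edges (False)
        low, high = (0.0, size - 1.0) if align_corners else (-0.5, size - 0.5)
        # clamp afterwards, since edge-based reflection can land at -0.5 or size - 0.5
        return clip_coordinates(reflect_coordinates(x, low, high), size)
    raise ValueError(f"unsupported padding mode: {mode}")
```

With size 4 and align_corners=False, the reflection boundaries sit at -0.5 and 3.5, so -1.0 maps to 0.0 and 4.0 maps to 3.0, while border padding simply clamps 5.0 to 3.0.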

Now, if you want to implement similar zero-padding behavior in TensorFlow, here's how I would do it:
Take a look at this _interpolate() function. First, I should note that this function differs from PyTorch's grid_sample() in two ways:

  1. It is not meant to be called externally, only as part of transformer(), so it takes the grid as two flattened tensors, x and y. If you follow the code on these lines, you can figure out how to reformat your grid this way. (There's no need to multiply by an affine matrix if you already have the grid you want to use; that step only produces an affine grid.)

  2. It miscalculates the conversion from the [-1, +1] range of the grid to pixel indices and is off by half a pixel. If half a pixel doesn't bother you, great. If it does, fixing it is a bit more involved, but still possible.
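Both points can be sketched in a few lines. Below, `flatten_grid` (hypothetical) splits a PyTorch-style grid, stored as rows of (x, y) pairs, into the two flat coordinate lists such a function expects, and `unnormalize` follows the mapping from [-1, 1] to pixel coordinates that PyTorch's grid sampler uses for both align_corners settings. `naive_unnormalize` is a hypothetical stand-in for a conversion of the form (x + 1) * W / 2; I'm assuming that is the kind of formula behind the half-pixel offset, so check it against the actual lines before relying on it.

```python
def flatten_grid(grid):
    """Split a grid given as rows of (x, y) pairs into flat x and y lists.

    PyTorch's grid stores x (width) first and y (height) second in its
    last dimension; rows are flattened in row-major order.
    """
    xs = [x for row in grid for (x, _) in row]
    ys = [y for row in grid for (_, y) in row]
    return xs, ys


def unnormalize(coord, size, align_corners=False):
    """Map a normalized coordinate in [-1, 1] to a pixel coordinate."""
    if align_corners:
        # -1 and +1 refer to the centers of the first and last pixels
        return (coord + 1) / 2 * (size - 1)
    # -1 and +1 refer to the outer edges of the first and last pixels
    return ((coord + 1) * size - 1) / 2


def naive_unnormalize(coord, size):
    # hypothetical: a conversion of this form is shifted by +0.5 pixel
    # relative to unnormalize(..., align_corners=False)
    return (coord + 1) * size / 2
```

The difference between the naive and align_corners=False mappings is a constant 0.5 pixel for every coordinate, so one way to correct such a conversion is to subtract 0.5 after it (or, equivalently, use the `unnormalize` form directly).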

So, supposing you can get over these two hurdles, how do you implement zero padding? First, remove the coordinate clipping on these lines, since you need to know when a grid point is out of bounds. Second, add a check for out-of-bounds grid points and zero out either their values here or their interpolation weights here.
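Here is a minimal sketch of the weight-masking option, assuming the coordinates have already been converted to pixel units and the clipping has been removed. `masked_bilinear_weights` is a hypothetical helper written per point for readability; in TensorFlow you would compute the same mask elementwise on whole tensors (comparisons combined with tf.logical_and, then cast to float) and multiply it into the weights.

```python
import math


def masked_bilinear_weights(x, y, H, W):
    """Return the four bilinear corners of pixel coordinate (x, y) with their
    weights, zeroing the weight of any corner outside the H x W image.

    x and y are deliberately *unclipped*, so out-of-bounds corners stay
    out of bounds and can be detected.
    """
    x0, y0 = math.floor(x), math.floor(y)
    x1, y1 = x0 + 1, y0 + 1
    dx, dy = x - x0, y - y0
    corners = [
        (y0, x0, (1 - dy) * (1 - dx)),
        (y0, x1, (1 - dy) * dx),
        (y1, x0, dy * (1 - dx)),
        (y1, x1, dy * dx),
    ]
    result = []
    for cy, cx, w in corners:
        in_bounds = 0 <= cx <= W - 1 and 0 <= cy <= H - 1
        # out-of-bounds corners keep their position but contribute nothing
        result.append((cy, cx, w if in_bounds else 0.0))
    return result
```

A point half a pixel past the left edge keeps only half of its total weight, and a point far outside keeps none, which matches the blend-with-zeros behavior of PyTorch's zero padding. With the clipping still in place, that far-out point would instead be pulled onto an edge pixel and return its value.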

That should be enough if you're running this on GPU. If you're running it on CPU, you need to be careful here not to index outside the flat 1D tensor, because otherwise you will probably get an out-of-bounds error. (If the lookups go through tf.gather, this is its documented behavior: on GPU an out-of-range index yields 0, while on CPU it raises an error.)
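One way to stay safe on CPU is to clamp the flat indices into range before the lookup and let the masked weights discard whatever the clamped index fetches. Note also that an out-of-bounds x can wrap into a valid flat index on the neighboring row (for example, y = 1, x = -1 maps to the last pixel of row 0), which is another reason the mask is needed on any device. This sketch uses plain Python lists; `flat_indices`, `clamped`, and `gather_weighted` are hypothetical names, and in TensorFlow the clamp would roughly correspond to tf.clip_by_value applied before tf.gather.

```python
def flat_indices(ys, xs, W):
    # row-major flattening: pixel (y, x) lives at y * W + x in the flat image
    return [y * W + x for y, x in zip(ys, xs)]


def clamped(indices, n):
    # keep every index inside [0, n - 1] so the lookup can never fail
    return [min(max(i, 0), n - 1) for i in indices]


def gather_weighted(flat_img, ys, xs, weights, H, W):
    """Sum flat_img at the given corners, weighted by already-masked weights.

    Out-of-bounds corners are assumed to carry weight 0, so whatever pixel
    the clamped index happens to fetch for them contributes nothing.
    """
    idx = clamped(flat_indices(ys, xs, W), H * W)
    # list lookup stands in for tf.gather; every index is now in range
    return sum(flat_img[i] * w for i, w in zip(idx, weights))
```

For a 2x3 image flattened to [1, 2, 3, 4, 5, 6], the corners (1, -1), (1, 0), (2, -1), (2, 0) of the point x = -0.5, y = 1 map to flat indices 2, 3, 5, and 6. Index 2 is a real pixel on the wrong row and index 6 is past the end; after clamping, every lookup succeeds, and the zero weights on the out-of-bounds corners leave 0.5 * 4.0 = 2.0.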

Hopefully this is enough to get you where you need to go.
