Change of 3D feature-map/image orientation after calling F.grid_sample with an identity transformation


I would like to use torch.nn.functional.grid_sample to resample a 3D image volume, so I tested it with identity displacements (see the code below). This should not spatially transform the image; however, it did re-orient the image. In the attached figure, the first row shows the image after resampling with F.grid_sample and the second row shows the original image.

Is there something I missed or did wrong? Any suggestions on how to fix this?

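```python
import numpy as np
import torch
import torch.nn.functional as F

def coordinates_map(size, start=-1., end=1.):
    '''Create a 3D sampling grid in [start, end] for a 5-D tensor of shape `size`.'''
    batch_size, channels = size[:2]
    d, h, w = size[2:]

    w_p = np.linspace(start, end, num=w)
    h_p = np.linspace(start, end, num=h)
    d_p = np.linspace(start, end, num=d)
    # With indexing='xy', the inputs (h_p, d_p, w_p) give arrays of shape (d, h, w);
    # coords has shape (1, d, h, w, 3) and its last axis holds the (h, d, w) coordinates.
    coords = np.stack(np.meshgrid(h_p, d_p, w_p, indexing='xy'), axis=-1)[np.newaxis, ...]

    return torch.from_numpy(coords).float().expand((batch_size,) + coords.shape[1:])

image = torch.rand((1, 1, 48, 65, 64))  # I used a real image, not a random tensor
grid = coordinates_map(size=image.shape)
deformed_image = F.grid_sample(image, grid)
```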
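For reference, F.grid_sample reads the last axis of a 5-D grid of shape (N, D_out, H_out, W_out, 3) as (x, y, z), where x moves along W, y along H and z along D. The probe below is a minimal sketch to confirm this; `vol` and `sample_at` are hypothetical names. Each voxel of the test volume stores 100*d + 10*h + w, so the sampled value reveals which index a coordinate moved. It uses `align_corners=True`, which maps -1 and +1 onto the first and last voxel centres.

```python
import torch
import torch.nn.functional as F

# Hypothetical test volume of shape (1, 1, D, H, W) where voxel (d, h, w) = 100*d + 10*h + w
D, H, W = 3, 4, 5
d = torch.arange(D).view(D, 1, 1)
h = torch.arange(H).view(1, H, 1)
w = torch.arange(W).view(1, 1, W)
vol = (100 * d + 10 * h + w).float().view(1, 1, D, H, W)

def sample_at(xyz):
    """Sample `vol` at a single normalized (x, y, z) point."""
    grid = torch.tensor(xyz, dtype=torch.float32).view(1, 1, 1, 1, 3)  # (N, D_out, H_out, W_out, 3)
    return F.grid_sample(vol, grid, align_corners=True).item()

print(sample_at([1.0, -1.0, -1.0]))  # x = +1 -> last W index: 4.0
print(sample_at([-1.0, 1.0, -1.0]))  # y = +1 -> last H index: 30.0
print(sample_at([-1.0, -1.0, 1.0]))  # z = +1 -> last D index: 200.0
```

So an identity grid must store the W coordinate in channel 0 and the D coordinate in channel 2, while the grid's own spatial axes stay in (D, H, W) order.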

Hey Sureerat, not sure if you've already solved this, but I was working on the exact same problem today. There is one more step you can add at the end: with the grid built as below, grid_sample returns the volume with its spatial axes reversed, so use .permute() on the output tensor to reorient it.

I found that to get the proper resampling I had to rearrange only the last three (spatial) dimensions:

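```python
import torch

x = torch.rand(1, 1, 48, 65, 64)  # input volume (N, C, D, H, W)
shape = x.shape[2:]               # spatial size (D, H, W); this was self.shape in my class

d1 = torch.linspace(-1, 1, shape[2])  # W
d2 = torch.linspace(-1, 1, shape[1])  # H
d3 = torch.linspace(-1, 1, shape[0])  # D
meshx, meshy, meshz = torch.meshgrid((d1, d2, d3), indexing='ij')  # each (W, H, D)
grid = torch.stack((meshx, meshy, meshz), 3)  # (W, H, D, 3), last axis is (x, y, z)
grid = grid.unsqueeze(0)  # add batch dim
# align_corners=True so that -1/+1 hit the corner voxel centres (exact identity)
out = torch.nn.functional.grid_sample(x, grid, align_corners=True)  # (N, C, W, H, D)
out = out.permute(0, 1, 4, 3, 2)  # back to (N, C, D, H, W)
```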
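An alternative that avoids the permute is to build the grid directly in (D, H, W) order and stack the coordinates as (x, y, z) = (W, H, D), which is the layout F.grid_sample expects. The sketch below assumes PyTorch 1.10 or newer (for `torch.meshgrid(..., indexing=...)`); `identity_grid` is a hypothetical helper name. It passes `align_corners=True` because a `linspace(-1, 1, n)` grid places -1 and +1 on the corner voxel centres; with the default `align_corners=False` the same grid is not an exact identity. The last lines cross-check the grid against `F.affine_grid` with an identity matrix.

```python
import torch
import torch.nn.functional as F

def identity_grid(d, h, w):
    """Hypothetical helper: identity sampling grid of shape (1, d, h, w, 3).

    The grid's spatial axes follow the input (D, H, W), while the last axis is
    ordered (x, y, z) = (W, H, D) coordinates, as F.grid_sample expects.
    """
    zs = torch.linspace(-1, 1, d)
    ys = torch.linspace(-1, 1, h)
    xs = torch.linspace(-1, 1, w)
    zz, yy, xx = torch.meshgrid(zs, ys, xs, indexing="ij")  # each (d, h, w)
    return torch.stack((xx, yy, zz), dim=-1).unsqueeze(0)

image = torch.rand(1, 1, 48, 65, 64)  # (N, C, D, H, W)
grid = identity_grid(*image.shape[2:])
out = F.grid_sample(image, grid, align_corners=True)  # same layout as image, no permute
print(torch.allclose(out, image, atol=1e-4))  # True

# Cross-check: F.affine_grid with an identity theta of shape (N, 3, 4) gives the same grid.
theta = torch.eye(3, 4).unsqueeze(0)
ref = F.affine_grid(theta, image.shape, align_corners=True)
print(torch.allclose(grid, ref, atol=1e-6))  # True
```

Building the grid this way keeps the sampled output in the input's layout, so it can be reused for real displacements by adding a (normalized) displacement field to `grid` before sampling.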