I am struggling to use `grid_sample` correctly and need some help.

My problem (simplified) is the following:

I have multiple 3D volumes of shape [3,3,3]. All values are 0 except the value at [0,0,0], which is 1.

In total, I have a batch of 5 volumes, so the resulting shape is

[5,3,3,3]

Following the `torch.nn.functional.grid_sample` documentation (PyTorch 2.0), I add a channel dimension with `unsqueeze(1)`, resulting in the shape [5,1,3,3,3].
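A minimal sketch of this input, assuming float32 tensors on the CPU; for 5D inputs `grid_sample` expects the layout (N, C, D_in, H_in, W_in):

```python
import torch

# Batch of 5 volumes, each 3x3x3, all zeros except a 1 at index [0, 0, 0]
base = torch.zeros(5, 3, 3, 3)
base[:, 0, 0, 0] = 1.0

# Add the channel dimension: (N, D, H, W) -> (N, C=1, D, H, W)
base = base.unsqueeze(1)
print(base.shape)  # torch.Size([5, 1, 3, 3, 3])
```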

I also have a flow field of shape [3,3,3,3], where the last dimension holds the displacement in the x/y/z directions. Again, the batch size is 5, so the resulting shape is [5,3,3,3,3].
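For reference, the grid `grid_sample` expects for a 5D input has shape (N, D_out, H_out, W_out, 3). Its entries are absolute sampling positions normalized to [-1, 1], and the last dimension is ordered (x, y, z), where x indexes W, y indexes H and z indexes D. A minimal sketch of an "identity" grid under that convention (the helper name `identity_grid` is hypothetical); with `align_corners=True`, -1 and +1 are the centers of the first and last voxel, so this grid should reproduce the input unchanged:

```python
import torch
import torch.nn.functional as F

def identity_grid(n, d, h, w):
    """Hypothetical helper: grid of shape (N, D, H, W, 3), last dim = (x, y, z)."""
    zs = torch.linspace(-1.0, 1.0, d)
    ys = torch.linspace(-1.0, 1.0, h)
    xs = torch.linspace(-1.0, 1.0, w)
    gz, gy, gx = torch.meshgrid(zs, ys, xs, indexing="ij")  # each (D, H, W)
    grid = torch.stack((gx, gy, gz), dim=-1)                 # (D, H, W, 3), x first
    return grid.unsqueeze(0).repeat(n, 1, 1, 1, 1)           # (N, D, H, W, 3)

base = torch.zeros(5, 1, 3, 3, 3)
base[:, 0, 0, 0, 0] = 1.0

grid = identity_grid(5, 3, 3, 3)
print(grid.shape)  # torch.Size([5, 3, 3, 3, 3])

# Sampling every voxel at its own position leaves the volume unchanged
same = F.grid_sample(base, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
print(torch.allclose(same, base))  # True
```

So the flow field already has the right shape; what differs is the meaning of its values.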

How can I use `grid_sample` to apply the flow and generate a warped image in which the value 1 is moved to another voxel? What shape is expected for the grid?
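One common approach is to convert the flow from voxels to normalized units and add it to an identity grid. This is a sketch, assuming the flow is given in voxels and that flow[..., 0], flow[..., 1], flow[..., 2] are the x (W), y (H) and z (D) displacements; `warp` is a hypothetical helper. Note that `grid_sample` does backward warping: each output voxel pulls its value from the position the grid points to, so moving the 1 forward by one voxel along x needs a flow of -1 in x.

```python
import torch
import torch.nn.functional as F

def warp(volume, flow):
    """Hypothetical helper: backward-warp volume (N, C, D, H, W) by flow (N, D, H, W, 3).

    Assumes flow is in voxels with last dim (x, y, z) = (W, H, D) displacement.
    Output voxel p takes its value from volume at p + flow[p].
    """
    n, _, d, h, w = volume.shape
    # Identity grid in normalized coordinates, last dim ordered (x, y, z)
    gz, gy, gx = torch.meshgrid(
        torch.linspace(-1.0, 1.0, d),
        torch.linspace(-1.0, 1.0, h),
        torch.linspace(-1.0, 1.0, w),
        indexing="ij",
    )
    identity = torch.stack((gx, gy, gz), dim=-1).to(volume)  # (D, H, W, 3)
    # With align_corners=True, one voxel step is 2 / (size - 1) normalized units
    scale = torch.tensor([2.0 / (w - 1), 2.0 / (h - 1), 2.0 / (d - 1)]).to(volume)
    grid = identity.unsqueeze(0) + flow * scale  # broadcasts to (N, D, H, W, 3)
    return F.grid_sample(volume, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=True)

base = torch.zeros(5, 1, 3, 3, 3)
base[:, 0, 0, 0, 0] = 1.0

# Move the 1 from [0, 0, 0] to [0, 0, 1]: every output voxel looks one voxel back in x
flow = torch.zeros(5, 3, 3, 3, 3)
flow[..., 0] = -1.0

warped = warp(base, flow)
print(warped[0, 0, 0, 0])  # tensor([0., 1., 0.])
```

Because of this pull semantics, a flow that describes where each voxel moves *to* (a forward flow) has to be inverted before it is used this way; simply negating it is only exact for a constant displacement like the one above.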

Right now, I don’t understand how

`warped = torch.nn.functional.grid_sample(base, flow, mode='bilinear', padding_mode='zeros', align_corners=True)`

moves the values of `base`.
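A small check of what that call does when the flow is passed directly as `grid`: `grid_sample` reads the grid values as absolute normalized positions, not as offsets. Even an all-zero ("no motion") flow therefore makes every output voxel sample the normalized point (0, 0, 0), which with `align_corners=True` is the center voxel [1, 1, 1]. A minimal sketch, rebuilding the base volume described above (for 5D inputs, `mode='bilinear'` is applied as trilinear interpolation):

```python
import torch
import torch.nn.functional as F

base = torch.zeros(5, 1, 3, 3, 3)
base[:, 0, 0, 0, 0] = 1.0

# "No motion" flow: all displacements are zero
flow = torch.zeros(5, 3, 3, 3, 3)

# Interpreted as absolute positions, every entry points at the center voxel,
# so each output voxel copies base[..., 1, 1, 1], which is 0: the 1 disappears
warped = F.grid_sample(base, flow, mode="bilinear", padding_mode="zeros", align_corners=True)
print(warped.sum())  # tensor(0.)
```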