I am currently struggling to use grid_sample the right way and need some help.
My problem (simplified) is the following:
I have multiple 3D volumes of shape [3,3,3]. All values are 0 except the value at [0,0,0], which is 1.
In total, I have 5 batches, so the resulting shape is [5,3,3,3].
Following the PyTorch 2.0 documentation for torch.nn.functional.grid_sample, I add the channel dimension with unsqueeze(1), resulting in the shape [5,1,3,3,3].
I also have a flow field of shape [3,3,3,3], where the last dimension holds the displacement in the x/y/z directions. Again there are 5 batches, so the resulting shape is [5,3,3,3,3].
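A minimal sketch of the setup described above, with an all-zero placeholder flow (the actual flow values are not given in the post). It also shows the grid layout grid_sample expects for 5D input, [N, D_out, H_out, W_out, 3]: the output takes its spatial size from the grid, not from the input, which is why a smaller hypothetical `small_grid` yields a smaller output.

```python
import torch
import torch.nn.functional as F

# 5 volumes of shape [3, 3, 3] (D, H, W): all zeros except voxel [0, 0, 0].
base = torch.zeros(5, 3, 3, 3)
base[:, 0, 0, 0] = 1.0
base = base.unsqueeze(1)  # add channel dim -> [N, C, D, H, W] = [5, 1, 3, 3, 3]

# Placeholder flow: one 3-vector per voxel, shape [N, D, H, W, 3].
flow = torch.zeros(5, 3, 3, 3, 3)

# grid_sample expects a grid of shape [N, D_out, H_out, W_out, 3];
# the output's spatial size comes from the grid (here 2x2x2), not from the input.
small_grid = torch.zeros(5, 2, 2, 2, 3)  # hypothetical grid, only to show the shapes
out = F.grid_sample(base, small_grid, mode='bilinear',
                    padding_mode='zeros', align_corners=True)

print(base.shape, flow.shape, out.shape)
```

So the flow's shape [5,3,3,3,3] already matches the grid layout. The open question is what the values inside the grid mean.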
How can I use grid_sample in order to apply the flow and generate a warped image where the value 1 is moved into another field? What shape is expected for the grid?
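A minimal sketch of one common approach, assuming the flow is given in voxel units with its last dimension ordered (x, y, z), where x runs along W, y along H, and z along D (the same order grid_sample uses). The idea: build an identity grid of normalized coordinates with F.affine_grid, convert the voxel displacement to normalized units, and add it. grid_sample does backward warping: output voxel p reads the input at p + flow[p]. So the uniform flow of -1 voxel used here (an illustrative value, not from the original problem) shifts the content by +1 voxel along each axis and moves the 1 from [0,0,0] to [1,1,1].

```python
import torch
import torch.nn.functional as F

N, D, H, W = 5, 3, 3, 3

base = torch.zeros(N, 1, D, H, W)
base[:, 0, 0, 0, 0] = 1.0

# Assumed convention: flow in voxel units, last dim = (dx, dy, dz) along (W, H, D).
# Backward warping: output[p] = input[p + flow[p]], so -1 everywhere shifts content by +1.
flow = torch.full((N, D, H, W, 3), -1.0)

# Identity grid in normalized [-1, 1] coordinates, shape [N, D, H, W, 3], last dim (x, y, z).
theta = torch.eye(3, 4).unsqueeze(0).repeat(N, 1, 1)
identity = F.affine_grid(theta, size=(N, 1, D, H, W), align_corners=True)

# With align_corners=True, one voxel step along an axis of size S is 2 / (S - 1) normalized units.
scale = torch.tensor([2.0 / (W - 1), 2.0 / (H - 1), 2.0 / (D - 1)])
grid = identity + flow * scale

# For 5D input, mode='bilinear' actually performs trilinear interpolation.
warped = F.grid_sample(base, grid, mode='bilinear',
                       padding_mode='zeros', align_corners=True)
print(warped[0, 0])
```

Positions that would read from outside the volume get 0 because of padding_mode='zeros'. Because the warp is backward, a flow that only changes voxel [1,1,1] would copy the 1 rather than move it, since voxel [0,0,0] would still read itself.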
Right now, I don't understand how
warped = torch.nn.functional.grid_sample(base, flow, mode='bilinear', padding_mode='zeros', align_corners=True) moves the values of base.
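The call above may explain the confusion: grid_sample treats each grid entry as an absolute normalized sampling location (x, y, z) in [-1, 1], not as a displacement. The sketch below passes two hypothetical flows directly as the grid. An all-zero "flow" samples the volume center (voxel [1,1,1], which is 0 here) everywhere. A constant -1 samples the corner voxel [0,0,0] everywhere (with align_corners=True), so every output voxel becomes 1.

```python
import torch
import torch.nn.functional as F

base = torch.zeros(5, 1, 3, 3, 3)
base[:, 0, 0, 0, 0] = 1.0

# (0, 0, 0) is the center of the volume, i.e. voxel [1, 1, 1], which holds 0.
zero_flow = torch.zeros(5, 3, 3, 3, 3)
out_zero = F.grid_sample(base, zero_flow, mode='bilinear',
                         padding_mode='zeros', align_corners=True)

# (-1, -1, -1) is voxel [0, 0, 0] when align_corners=True, so every output voxel reads the 1.
corner_flow = torch.full((5, 3, 3, 3, 3), -1.0)
out_corner = F.grid_sample(base, corner_flow, mode='bilinear',
                           padding_mode='zeros', align_corners=True)

print(out_zero.sum().item(), out_corner.sum().item())  # 0.0 135.0
```

In other words, a zero flow does not mean "no movement" when passed directly. The identity grid has to be added first.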