For a given theta, I want to rotate a 3D tensor (D, H, W) in both real space and Fourier space. In real space I apply grid_sample to the volume as usual. In Fourier space I take a 3D FFT of the volume, apply grid_sample to the real and imaginary parts separately, combine the results with torch.complex(), and inverse-FFT back to real space. However, the two approaches give completely different results. Here’s the example code snippet:
import numpy as np
import torch
import torch.nn.functional as F

theta = np.array([[0., -1., 0., 0.], [1., 0., 0., 0.], [0., 0., 1., 0.]])
theta = torch.tensor(theta, requires_grad=False).float().unsqueeze(0)
torch.manual_seed(0)
volume = torch.rand((5,5,5))
volume_F = torch.fft.fftshift(torch.fft.fftn(volume))
d, h, w = volume.shape
b = 1
grid = F.affine_grid(theta, (b, 1, d, h, w))
volume = volume.expand((b, d, h, w)) # (B, D, H, W)
volume_F = volume_F.expand((b, d, h, w)) # (B, D, H, W)
rot_vol = F.grid_sample(
    torch.unsqueeze(volume, 1),
    grid,
    mode='bilinear',
    padding_mode="zeros",
    align_corners=True).squeeze(1) # (B, D, H, W)
rot_vol_r = F.grid_sample(
    torch.unsqueeze(volume_F.real.float(), 1),
    grid,
    mode='bilinear',
    padding_mode="zeros",
    align_corners=True).squeeze(1) # (B, D, H, W)
rot_vol_i = F.grid_sample(
    torch.unsqueeze(volume_F.imag.float(), 1),
    grid,
    mode='bilinear',
    padding_mode="zeros",
    align_corners=True).squeeze(1) # (B, D, H, W)
rot_vol_F = torch.complex(rot_vol_r, rot_vol_i)
rot_vol2 = torch.fft.ifftn(torch.fft.ifftshift(rot_vol_F[0])).real
print(rot_vol[0,1:4,1:4,1:4])
print(rot_vol2[1:4,1:4,1:4])
I printed the central 3x3x3 block of the rotated volume, first for the real-space rotation and then for the Fourier-space rotation:
tensor([[[0.3287, 0.7239, 0.3600],
         [0.2875, 0.7846, 0.6296],
         [0.6597, 0.4023, 0.3662]],

        [[0.2285, 0.5616, 0.7294],
         [0.1306, 0.2038, 0.5334],
         [0.5179, 0.6016, 0.5511]],

        [[0.2563, 0.5988, 0.7047],
         [0.6328, 0.2577, 0.6991],
         [0.7618, 0.4310, 0.6676]]])
tensor([[[0.3730, 0.6217, 0.2595],
         [0.3322, 0.3079, 0.1674],
         [0.2382, 0.4306, 0.2531]],

        [[0.3308, 0.3927, 0.4002],
         [0.1625, 0.2234, 0.2055],
         [0.0988, 0.1105, 0.1756]],

        [[0.2901, 0.1826, 0.1287],
         [0.1781, 0.1976, 0.2110],
         [0.2818, 0.1371, 0.2022]]])
The results are completely different.
My question is, what could be wrong here? Thanks a lot in advance for any opinions!
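One thing that may be worth checking is where each rotation is centered. grid_sample rotates about the middle of the volume, while the DFT treats index 0 as the real-space origin, so rotating the shifted spectrum corresponds to a periodic rotation about voxel 0 rather than about the center. Below is a minimal sketch under two assumptions that are not in the code above: the volume is wrapped with ifftshift before the forward FFT and fftshift after the inverse FFT, and align_corners=True is also passed to affine_grid so the grid uses the same convention as grid_sample. The names rotate, rot_real, rot_fourier and max_diff are hypothetical and only used for illustration.

```python
import torch
import torch.nn.functional as F


def rotate(vol, grid):
    """Resample a real (D, H, W) tensor on `grid` with trilinear interpolation."""
    return F.grid_sample(
        vol[None, None].contiguous(),  # (1, 1, D, H, W)
        grid,
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )[0, 0]


torch.manual_seed(0)
volume = torch.rand(5, 5, 5)
theta = torch.tensor([[[0., -1., 0., 0.],
                       [1., 0., 0., 0.],
                       [0., 0., 1., 0.]]])
# Assumption: align_corners=True here as well, to match grid_sample.
grid = F.affine_grid(theta, (1, 1, 5, 5, 5), align_corners=True)

# Real space: rotation about the center of the volume.
rot_real = rotate(volume, grid)

# Fourier space: ifftshift moves the volume center to index 0 (the DFT origin),
# and fftshift moves the zero-frequency term to the center of the grid.
vol_F = torch.fft.fftshift(torch.fft.fftn(torch.fft.ifftshift(volume)))
rot_F = torch.complex(rotate(vol_F.real, grid), rotate(vol_F.imag, grid))
# Undo both shifts on the way back to real space.
rot_fourier = torch.fft.fftshift(torch.fft.ifftn(torch.fft.ifftshift(rot_F))).real

max_diff = (rot_real - rot_fourier).abs().max().item()
print(max_diff)
```

For this 90-degree rotation of an odd-sized cube every sample lands exactly on a grid point, so if centering is the cause, the two results should agree to within float32 round-off. For other angles, linear interpolation of the spectrum will likely still introduce some difference.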