How do I use torch.fft on a batch of 2D RGB images (52 of them here)?
imgs.shape
torch.Size([52, 3, 128, 128])
Thanks
The Fourier transform works on intensities, and an RGB image doesn't have a single defined intensity per pixel. You could try splitting the image into its R, G and B channels and then running torch.fft on each of them?
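import torch

# torch.rfft is the legacy API; newer PyTorch releases removed it in favour
# of the torch.fft module.
image = torch.randn([52, 3, 128, 128])
# Note: signal_ndim=1 transforms only the last dimension, which after the
# first permute is the size-1 channel axis, not H or W.
r = torch.rfft(image[:, 0:1].permute([0, 2, 3, 1]), signal_ndim=1).permute([0, 3, 1, 2, 4])
g = torch.rfft(image[:, 1:2].permute([0, 2, 3, 1]), signal_ndim=1).permute([0, 3, 1, 2, 4])
b = torch.rfft(image[:, 2:3].permute([0, 2, 3, 1]), signal_ndim=1).permute([0, 3, 1, 2, 4])
image_fft = torch.cat([r, g, b], dim=1)
image_fft.shape
>> torch.Size([52, 3, 128, 128, 2])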
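For readers on a current PyTorch version, here is a minimal sketch of the same "split into R, G, B and transform each" idea using the torch.fft module (assumed available, PyTorch 1.8+), this time transforming the spatial dimensions H and W of each channel. The random tensor is only a stand-in for the real batch.

```python
import torch

# Random stand-in for the batch from the question: 52 RGB images, 128x128.
image = torch.randn(52, 3, 128, 128)

# Take each channel separately (keeping a size-1 channel dim), apply a 2-D FFT
# over its last two dims (H, W), then concatenate back along dim 1.
channels = [torch.fft.fft2(image[:, c:c + 1], dim=(-2, -1)) for c in range(3)]
image_fft = torch.cat(channels, dim=1)  # complex tensor, shape [52, 3, 128, 128]

print(image_fft.shape, image_fft.dtype)
```

Because fft2 only touches the last two dimensions, the explicit per-channel split is optional: calling torch.fft.fft2(image) on the whole tensor should give the same result.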
Apologies, what I wanted is fft2.
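A minimal sketch of a 2-D FFT over the whole batch, assuming PyTorch 1.8+ where torch.fft.fft2 and torch.fft.fftshift exist; the random input and the fftshift step (often used for visualising spectra) are illustrative additions, not part of the original question.

```python
import torch

imgs = torch.randn(52, 3, 128, 128)

# fft2 transforms the last two dims (H, W) by default; the batch (52) and
# channel (3) dims are carried along untouched.
spectrum = torch.fft.fft2(imgs)  # complex tensor, shape [52, 3, 128, 128]

# Optional: move the zero-frequency term to the centre of each spectrum.
centered = torch.fft.fftshift(spectrum, dim=(-2, -1))
magnitude = centered.abs()  # real tensor, shape [52, 3, 128, 128]

# The inverse transform recovers the original images up to float error.
recovered = torch.fft.ifft2(spectrum).real
print(spectrum.shape, torch.allclose(recovered, imgs, atol=1e-4))
```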
Hi, why is signal_ndim=1 instead of 2?
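One way to see why signal_ndim=1 matters in the split-channel code: after the permute, the last dimension is the size-1 channel axis, so a 1-D transform over it is a length-1 DFT, which just returns the input value. The sketch below reproduces that layout with torch.fft.rfft (assuming PyTorch 1.8+) and contrasts it with a 2-D transform over H and W.

```python
import torch

image = torch.randn(52, 3, 128, 128)

# Same layout as the split-channel answer: one channel moved to the last dim.
r = image[:, 0:1].permute(0, 2, 3, 1)  # shape [52, 128, 128, 1]

# A 1-D real FFT over the last dim (what signal_ndim=1 did) sees a length-1
# signal; its DFT is the value itself with zero imaginary part.
r_1d = torch.fft.rfft(r, dim=-1)  # complex, shape [52, 128, 128, 1]

# A 2-D real FFT over H and W is what a spatial spectrum needs.
spatial = torch.fft.rfft2(image[:, 0:1])  # complex, shape [52, 1, 128, 65]

print(torch.allclose(r_1d.real, r), spatial.shape)
```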
For anyone else reading this, I believe this is a more elegant implementation:
torch.rfft(imgs, signal_ndim=2, normalized=True)
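This works because torch.rfft() transforms the last signal_ndim dimensions and treats every dimension before them as a batch dimension. With signal_ndim=2, the leading [52, 3] dimensions are treated as batch dimensions and the FFT is performed only over the last two, [128, 128]. Since onesided=True by default, the output shape is [52, 3, 128, 65, 2].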
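For newer PyTorch versions (assumed 1.8+, where torch.rfft is no longer available), a minimal sketch of what appears to be the equivalent call: torch.fft.rfft2 with norm="ortho", which applies the same 1/sqrt(H*W) scaling that normalized=True did. torch.view_as_real reproduces the legacy trailing real/imag dimension. The check at the end illustrates the batch-dimension point above; the indices 7 and 2 are arbitrary.

```python
import torch

imgs = torch.randn(52, 3, 128, 128)

# rfft2 transforms the last two dims; norm="ortho" scales by 1/sqrt(H*W).
spec = torch.fft.rfft2(imgs, norm="ortho")  # complex, shape [52, 3, 128, 65]

# The legacy API returned real/imag as a trailing dim of size 2.
legacy_layout = torch.view_as_real(spec)  # real, shape [52, 3, 128, 65, 2]

# Leading dims are batch dims: transforming a single image/channel on its own
# gives the same result as the corresponding slice of the batched transform.
one = torch.fft.rfft2(imgs[7, 2], norm="ortho")  # shape [128, 65]
same = torch.allclose(torch.view_as_real(one), legacy_layout[7, 2], atol=1e-4)
print(legacy_layout.shape, same)
```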