How to use torch.fft for color images

How do I use torch.fft on a batch of 2D RGB images (52 of them here)?

imgs.shape
torch.Size([52, 3, 128, 128])

Thanks

The 2D Fourier transform operates on a single value (an intensity) per pixel, and an RGB pixel has three values rather than one defined intensity. You could try splitting the image into its R, G and B channels and running torch.fft on each of them?

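import torch

image = torch.randn([52, 3, 128, 128])
# Move each channel to the last axis ([52, 128, 128, 1]), apply a 1D real FFT
# over that axis (signal_ndim=1), then move it back: [52, 1, 128, 128, 2].
# torch.rfft only exists in PyTorch < 1.8.
r = torch.rfft(image[:, 0:1].permute([0, 2, 3, 1]), signal_ndim=1).permute([0, 3, 1, 2, 4])
g = torch.rfft(image[:, 1:2].permute([0, 2, 3, 1]), signal_ndim=1).permute([0, 3, 1, 2, 4])
b = torch.rfft(image[:, 2:3].permute([0, 2, 3, 1]), signal_ndim=1).permute([0, 3, 1, 2, 4])

# Stack the three channel results back along dim 1
image_fft = torch.cat([r, g, b], dim=1)
image_fft.shape
>>torch.Size([52, 3, 128, 128, 2])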
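With the torch.fft module that replaced torch.rfft in PyTorch 1.8, the per-channel idea can be sketched as below, assuming a 2D transform per channel is what is wanted. The helper name `fft2_per_channel` is hypothetical. Splitting turns out not to be required: torch.fft.fft2 transforms the last two dimensions by default and treats every dimension before them as a batch dimension, which the final comparison checks.

```python
import torch


def fft2_per_channel(images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 2D FFT of each channel separately, then restack.

    images: real tensor of shape [N, C, H, W]; returns complex [N, C, H, W].
    """
    channels = []
    for c in range(images.shape[1]):
        # images[:, c] has shape [N, H, W]; fft2 transforms its last two dims (H, W)
        channels.append(torch.fft.fft2(images[:, c]))
    # Restack the per-channel spectra along the channel dimension
    return torch.stack(channels, dim=1)


image = torch.randn(52, 3, 128, 128)
per_channel = fft2_per_channel(image)

# The same result in one call: dims 0 and 1 ([52, 3]) act as batch dims
batched = torch.fft.fft2(image)

print(per_channel.shape)  # torch.Size([52, 3, 128, 128])
print(per_channel.dtype)  # torch.complex64
print(torch.allclose(per_channel, batched, atol=1e-3))
```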

Apologies, what I wanted is fft2.
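Since fft2 was the goal, here is a small sketch of how the full 2D transform relates to the one-sided real transform, assuming PyTorch 1.8+ (torch.fft module). For real input the spectrum is Hermitian-symmetric, so torch.fft.rfft2 keeps only the first W // 2 + 1 columns of the last dimension, and those columns match torch.fft.fft2.

```python
import torch

imgs = torch.randn(52, 3, 128, 128)

full = torch.fft.fft2(imgs)   # complex, shape [52, 3, 128, 128]
half = torch.fft.rfft2(imgs)  # complex, shape [52, 3, 128, 65] (128 // 2 + 1 columns)

# The one-sided result equals the first 65 columns of the full spectrum
print(torch.allclose(full[..., :65], half, atol=1e-3))

# Hermitian symmetry of a real signal: X[k1, k2] == conj(X[-k1 mod H, -k2 mod W]);
# Python's negative indexing performs the "mod size" wrap-around here
k1, k2 = 5, 100
print(torch.allclose(full[..., k1, k2], full[..., -k1, -k2].conj(), atol=1e-3))
```

This is why rfft2 is usually enough for real images: the dropped columns carry no extra information.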

Hi, why is signal_ndim=1 instead of 2?
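To see what signal_ndim=1 does in the reply's code: each channel is permuted so that its last axis has size 1, and a 1D transform over a length-1 axis returns each value unchanged with a zero imaginary part. A minimal sketch with the modern torch.fft.rfft (PyTorch 1.8+), reusing the same permutation:

```python
import torch

image = torch.randn(52, 3, 128, 128)

# Same layout as the reply: one channel moved to the last axis -> [52, 128, 128, 1]
red_last = image[:, 0:1].permute(0, 2, 3, 1)

# 1D real FFT over that length-1 axis (the modern counterpart of signal_ndim=1)
spec = torch.fft.rfft(red_last, dim=-1)
print(spec.shape)  # torch.Size([52, 128, 128, 1])

# A length-1 DFT is the identity: real part equals the input, imaginary part is zero
print(torch.allclose(spec.real, red_last))
print(torch.allclose(spec.imag, torch.zeros_like(spec.imag)))
```

So no spatial frequency information is computed; a 2D transform over the two 128-sized axes is what fft2 refers to.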

For anyone else using this, I believe this is a more elegant implementation:

torch.rfft(imgs, signal_ndim=2, normalized=True)

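This works because torch.rfft() transforms the last signal_ndim dimensions of its input, so with signal_ndim=2 every dimension before dimension 2 is treated as a batch dimension. Hence [52, 3] is treated as the batch and the 2D FFT is performed only over [128, 128] (the output shape is [52, 3, 128, 65, 2], since the last dimension is one-sided by default).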
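torch.rfft was deprecated in PyTorch 1.7 and removed in 1.8. Below is a sketch of the equivalent call with the torch.fft module, assuming that normalized=True corresponds to norm="ortho" (scaling by 1 / sqrt(H * W)) and that the old default onesided=True corresponds to rfft2's one-sided output. torch.view_as_real restores the trailing real/imaginary dimension of size 2 that the old function returned.

```python
import torch

imgs = torch.randn(52, 3, 128, 128)

# rfft2 transforms the last two dims by default; [52, 3] act as batch dims.
# norm="ortho" scales the result by 1 / sqrt(128 * 128).
spec = torch.fft.rfft2(imgs, norm="ortho")
print(spec.shape)  # torch.Size([52, 3, 128, 65])

# Old-style layout: trailing dim of size 2 holding (real, imag)
spec_ri = torch.view_as_real(spec)
print(spec_ri.shape)  # torch.Size([52, 3, 128, 65, 2])

# "ortho" equals the unnormalized transform divided by sqrt(H * W) = 128
print(torch.allclose(spec, torch.fft.rfft2(imgs) / 128, atol=1e-4))

# Round trip: pass s=(H, W) so the original width of 128 is recovered
restored = torch.fft.irfft2(spec, s=imgs.shape[-2:], norm="ortho")
print(torch.allclose(restored, imgs, atol=1e-4))
```

Passing `s` to irfft2 matters because a one-sided width of 65 could have come from an original width of either 128 or 129.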