Help with Fourier transform and vision CNNs


I recently started working on a research project at school. Here is a quick summary.
Input image = I
Kernel = K

Apply the Fourier transform to the image I. Filter out the 50% highest frequencies. Then do the convolution with K in the frequency domain and pass the result through a smaller CNN. Finally, apply the inverse FT to get the results.
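
To make the steps concrete, here is a rough sketch of the FFT, frequency-domain convolution, low-pass, and inverse FT part. It rests on several assumptions: "50% highest frequencies" is read as keeping the central half of the shifted spectrum along each axis, the convolution is circular (the kernel is zero-padded to the image size), and the names `freq_conv_lowpass`, `to_spatial`, and `keep` are made up for illustration.

```python
import torch

def freq_conv_lowpass(img, kernel, keep=0.5):
    """img: (B, C, H, W) real tensor, kernel: (kh, kw) real tensor.
    Returns the centered complex spectrum of the (circular) convolution,
    cropped to the lowest `keep` fraction of frequencies along each axis."""
    H, W = img.shape[-2:]
    spec_img = torch.fft.fft2(img)
    # Zero-pad the kernel to (H, W) so both spectra have the same shape.
    spec_ker = torch.fft.fft2(kernel, s=(H, W))
    # Convolution theorem: element-wise product == circular convolution.
    prod = torch.fft.fftshift(spec_img * spec_ker, dim=(-2, -1))
    # After fftshift the DC term sits at (H // 2, W // 2); crop around it
    # so that it ends up at (h // 2, w // 2) in the cropped spectrum.
    h, w = int(H * keep), int(W * keep)
    top, left = H // 2 - h // 2, W // 2 - w // 2
    return prod[..., top:top + h, left:left + w]

def to_spatial(spec, orig_hw):
    """Inverse FT of a centered, cropped spectrum: a smaller, low-passed image."""
    H, W = orig_hw
    h, w = spec.shape[-2:]
    out = torch.fft.ifft2(torch.fft.ifftshift(spec, dim=(-2, -1)))
    # Rescale so pixel intensities match the original image's range.
    return out.real * (h * w) / (H * W)

img = torch.rand(1, 3, 224, 224)
kernel = torch.rand(3, 3)
spec = freq_conv_lowpass(img, kernel)     # complex, shape (1, 3, 112, 112)
small = to_spatial(spec, img.shape[-2:])  # real, shape (1, 3, 112, 112)
```

Cropping the spectrum is what shrinks the tensor: dropping the high frequencies and inverting gives a downsampled image, which is where a smaller network could save compute.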

Can somebody please help me implement this in PyTorch? (I am mediocre with Python and PyTorch.) I am stuck at the step after doing the convolution in the frequency domain (the element-wise product). How should I continue the process in the frequency domain, and how can I use a smaller network there? For example, I would start from a 224x224 input image and then, in the frequency domain, use a smaller ResNet on magnitude-only images of around 50x50, so that I actually benefit from filtering out the high frequencies.
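
For the "smaller network on magnitude-only images" idea, a minimal sketch could look like the following. Everything here is an assumption for illustration: `SmallSpectrumCNN` is a tiny stand-in for a smaller ResNet, the task is taken to be classification with 10 classes, and the log-magnitude is used because raw FFT magnitudes span a very wide range.

```python
import torch
import torch.nn as nn

class SmallSpectrumCNN(nn.Module):
    """Hypothetical small CNN that consumes the magnitude of a cropped spectrum."""
    def __init__(self, in_ch=3, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_ch, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # works for any input size, e.g. 50x50
        )
        self.head = nn.Linear(32, num_classes)

    def forward(self, spec):
        # spec is complex; keep only the magnitude and compress its range.
        x = torch.log1p(spec.abs())
        return self.head(self.features(x).flatten(1))

img = torch.rand(2, 3, 224, 224)
spec = torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1))
# Central 50x50 block around the DC term at index 112 (low frequencies only).
spec = spec[..., 87:137, 87:137]
logits = SmallSpectrumCNN()(spec)  # shape (2, 10)
```

One thing to keep in mind with this design: the magnitude alone drops the phase, so an inverse FT of the network's output will not give back a meaningful image unless the phase is kept and recombined somewhere.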

Thanks in advance.