How to optimize a PyTorch CNN with complex numbers (no gradients used)?

Hi,
A basic recap of what we are doing: my team and I are building a Cellular Automata program (effectively convolutional networks) in PyTorch. Because we are not training them via backprop, we don’t need or use the gradients; they are fully disabled. Our kernels are quite large, so we use the FFT trick to speed up the convolution (we have tested it, and it is indeed faster than standard convolution). Our “image” sizes are in the range of H, W = 500-100, C = 1-3, and batch = 1-100.
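
For context, the FFT-based convolution looks roughly like the sketch below. The sizes, names, and the circular-convolution assumption are placeholders to show the general approach, not our exact code:

import torch

# Placeholder sizes, just to illustrate the setup
B, C, H, W = 4, 3, 512, 512
x = torch.randn(B, C, H, W, device="cuda")
kernel = torch.randn(C, C, H, W, device="cuda")  # large kernel, same spatial size as the image

# FFT trick: (circular) convolution becomes pointwise multiplication in the frequency domain
x_f = torch.fft.rfft2(x)                       # (B, C, H, W//2 + 1)
k_f = torch.fft.rfft2(kernel)                  # (C, C, H, W//2 + 1)
conv_f = x_f.unsqueeze(1) * k_f                # broadcasts to (B, C, C, H, W//2 + 1)
conv_out = torch.fft.irfft2(conv_f, s=(H, W))  # (B, C, C, H, W), real-valued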

After the convolution, we pass the result to a custom forward function. The tensors the convolution produces have shape (Batch, C, C, H, W), and they are multiplied with tensors that represent Fourier series with Num_harmonics harmonics and have shape (Batch, Num_harmonics, C, C, 1, 1). Effectively, we apply a custom function, constructed from the Fourier series, to every pixel resulting from the convolution.

The custom forward function then looks like this:

growth = torch.sum(self.coefficients * torch.exp(1j * 2 * self.harmonics * torch.pi * (x - self.shifts) / self.period), dim=1).real

Where self.coefficients, self.harmonics, self.shifts, and self.period are batched tensors, and x is the result of the convolution, expanded to match the shape of the Fourier series tensors.
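
Put together, a self-contained sketch of the forward step looks like this (made-up sizes and placeholder initialisations, only to show the shapes and the broadcasting):

# Placeholder sizes for illustration
B, C, H, W, NH = 4, 3, 512, 512, 8
conv_out = torch.randn(B, C, C, H, W, device="cuda")  # output of the FFT convolution

coefficients = torch.randn(B, NH, C, C, 1, 1, device="cuda")
# broadcastable stand-in for the (B, NH, C, C, 1, 1) harmonics tensor
harmonics = torch.arange(1, NH + 1, device="cuda").float().reshape(1, NH, 1, 1, 1, 1)
shifts = torch.randn(B, NH, C, C, 1, 1, device="cuda")
period = torch.rand(B, NH, C, C, 1, 1, device="cuda") + 1.0

x = conv_out.unsqueeze(1)  # (B, 1, C, C, H, W), broadcasts against the (B, NH, C, C, 1, 1) tensors
growth = torch.sum(
    coefficients * torch.exp(1j * 2 * harmonics * torch.pi * (x - shifts) / period),
    dim=1,                 # sum over the harmonics
).real                     # (B, C, C, H, W)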

The issue is that batch sizes above 16 become quite slow. We know that the equation above is causing the problem, because we have tested the same pipeline with other equations, such as:

growth = lambda x: 2 * torch.exp(-((x - mu) ** 2 / sigma ** 2) / 2) - 1

where x, mu, and sigma are batched tensors.

The equation above runs about five times faster, so it seems that the complex-number arithmetic is causing the slowdown.
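
For reference, the two variants were compared roughly along these lines (a simplified timing sketch, reusing the placeholder tensors from the sketch above; the shapes of mu and sigma here are made up for illustration):

import time

def time_fn(fn, iters=50, warmup=5):
    # Warm-up, then time with explicit synchronisation, since CUDA kernels launch asynchronously
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

mu = torch.randn(B, C, C, 1, 1, device="cuda")      # placeholder shapes
sigma = torch.rand(B, C, C, 1, 1, device="cuda") + 0.1

fourier_s = time_fn(lambda: torch.sum(
    coefficients * torch.exp(1j * 2 * harmonics * torch.pi * (conv_out.unsqueeze(1) - shifts) / period),
    dim=1).real)
gaussian_s = time_fn(lambda: 2 * torch.exp(-((conv_out - mu) ** 2 / sigma ** 2) / 2) - 1)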

We are using an RTX 4070 and are nowhere near the memory cap (about 3 GB out of 12 GB). Everything stays on the GPU; we never transfer back to the CPU. Are there any tricks we can use to speed up the computation, such as the channels-last memory format for the FFT, or reshaping the tensors?

Kind regards,
Etienne