Convolution using FFT is slower than Conv2d 🤔

Hello folks,

According to my understanding, a convolution with a kernel of size k by k, input batch of size B, c input channels, f output channels, and an image of size n by n requires the following number of multiplications:

$$B \cdot f \cdot c \cdot k^2 \cdot n^2$$

The convolution can also be done in the FFT domain by an element-wise multiplication. Disregarding the cost of the FFT operations for simplicity, we would have (note that each complex multiplication requires 4 real ones)

$$4 \cdot B \cdot f \cdot c \cdot n^2$$
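To make the "element-wise multiplication in the FFT domain" claim concrete, here is a minimal 1-D sketch using only the standard library. The helper names (`dft`, `idft`, `circular_conv`) are illustrative and not taken from the Gist, and the naive O(n^2) DFT stands in for a real FFT. Note that the product of the transforms yields a circular convolution, so a practical implementation would zero-pad the inputs first.

```python
import cmath

def dft(x):
    """Naive O(n^2) discrete Fourier transform (stand-in for an FFT)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]

def idft(X):
    """Naive inverse DFT."""
    n = len(X)
    return [sum(X[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n
            for t in range(n)]

def circular_conv(x, h):
    """Direct circular convolution of two equal-length sequences."""
    n = len(x)
    return [sum(x[m] * h[(t - m) % n] for m in range(n)) for t in range(n)]

x = [1.0, 2.0, 3.0, 4.0]
h = [1.0, 0.0, -1.0, 0.0]  # kernel zero-padded to the signal length

direct = circular_conv(x, h)

# Transform both, multiply element-wise, transform back.
product = [a * b for a, b in zip(dft(x), dft(h))]
via_fft = [v.real for v in idft(product)]

print(direct)
print([round(v, 6) for v in via_fft])
```

Both lists should agree up to floating-point rounding.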

P.S.: I don't know why formulas in Markdown are not rendering here.

This plot shows that using the FFT should be faster, even when I add the cost of the FFT operations themselves.
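The comparison behind that claim can be sketched numerically. The two multiplication counts below are the formulas from this post; the FFT cost model (roughly n^2 log2(n^2) operations per 2-D transform, constants ignored, no padding) is an assumption for illustration, as are f = 64 and k = 5. B, c and n match the input tensor used later in the thread.

```python
import math

def direct_mults(B, f, c, k, n):
    """Multiplications for a direct convolution (formula from the post)."""
    return B * f * c * k**2 * n**2

def fft_pointwise_mults(B, f, c, n):
    """Real multiplications for the element-wise complex product."""
    return 4 * B * f * c * n**2

def fft_transform_cost(B, f, c, n):
    """ASSUMED rough model: ~n^2 * log2(n^2) operations per 2-D transform."""
    per_transform = n**2 * math.log2(n**2)
    # Forward transforms of inputs and filters, inverse transforms of outputs.
    num_transforms = B * c + f * c + B * f
    return num_transforms * per_transform

B, c, n = 128, 64, 32   # from the tensor shape used later in the thread
f, k = 64, 5            # ASSUMED: output channels and kernel size

direct = direct_mults(B, f, c, k, n)
pointwise = fft_pointwise_mults(B, f, c, n)
transforms = fft_transform_cost(B, f, c, n)

print(f"direct:            {direct:.3e}")
print(f"FFT pointwise:     {pointwise:.3e}")
print(f"FFT transforms:    {transforms:.3e}")
print(f"direct / FFT total: {direct / (pointwise + transforms):.2f}")
```

This only counts arithmetic. It says nothing about memory traffic, kernel launches, or how well each path is optimized on the GPU, which is where the measured timings can diverge from the estimate.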

However, when I tested it (see the Gist), the time spent on the multiplication operations was way longer than anything else.

Also, I had to be very careful: any small change could easily cause an allocation error due to insufficient memory on the GPU.

Can anyone give me a hint on why Conv2d is faster than the FFT version? It is supposed to be the other way around. 🤔

At least part of your performance loss is that you had to reimplement complex multiplication by hand using an unfused series of operations. Each of those operations runs separately and materializes its own intermediate tensor, so fusing them would be a big win. You should be able to use the PyTorch JIT’s fuser to test this.
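As a rough illustration of what that could look like, here is a minimal sketch of a hand-written complex multiply on separate real and imaginary tensors, wrapped in `torch.jit.script` so the JIT has a chance to fuse the pointwise operations. The function name `complex_mul` and the tensor shapes are made up for this example and are not from the Gist. Whether fusion actually happens depends on the PyTorch version and the device, so treat this as something to benchmark rather than a guaranteed speedup.

```python
from typing import Tuple

import torch

@torch.jit.script
def complex_mul(ar: torch.Tensor, ai: torch.Tensor,
                br: torch.Tensor, bi: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor]:
    # (ar + i*ai) * (br + i*bi): four multiplies, one subtract, one add.
    # Keeping them in one scripted function lets the JIT try to fuse them.
    real = ar * br - ai * bi
    imag = ar * bi + ai * br
    return real, imag

# Small CPU tensors, just to check correctness (shapes are arbitrary).
ar, ai, br, bi = (torch.rand(4, 8) for _ in range(4))
real, imag = complex_mul(ar, ai, br, bi)

# Reference result using PyTorch's native complex tensors.
expected = torch.complex(ar, ai) * torch.complex(br, bi)
print(torch.allclose(real, expected.real), torch.allclose(imag, expected.imag))
```

On recent PyTorch versions, native complex tensors may make the hand-written version unnecessary.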


Thanks for the feedback. I am not familiar with the term PyTorch JIT’s fuser. Is this where I should be reading from?

I just got these two useful links. I will go through them and try a second implementation.

I wrapped my function calls inside a torch.nn.Module and used the JIT.

I know that my code is really messy; I haven't had time yet to make it more readable or better written. Hopefully it is good enough to show that the timings are still in favor of the traditional Conv2d.

| Conv2d | Conv2dFFT | Conv2dFFT (without JIT) |
|--------|-----------|-------------------------|
| 0.2027 | 0.8736    | 2.9177                  |

I am using the JIT like this; I am not sure how to use the fuser as you mentioned.

class Conv2d(torch.nn.Module):
  def forward(self, x, filters):
    return torch.nn.functional.conv2d(x, filters, padding=2)

a = torch.rand(128, 64, 32, 32).cuda()
conv2d_spatial = Conv2d()
conv2d_spatial_opt = torch.jit.trace(conv2d_spatial, (a, filters))


class Conv2dFFT(torch.nn.Module):
  def forward(self, x, h):
    return fft_conv(x, h)

conv2d_fft = Conv2dFFT()
conv2d_fft_opt = torch.jit.trace(conv2d_fft, (a, filters))