Speed up FFT Convolution Layer

Hello,

FFT convolution should theoretically be faster than direct convolution past a certain kernel size: an FFT-based convolution costs roughly O(H·W·log(H·W)) per channel pair, versus O(H·W·k²) for a direct k×k convolution. Since PyTorch added FFT support in version 0.4.0, I decided to try implementing an FFT convolution layer.

It turns out to be quite a bit slower than the built-in torch.nn.functional.conv2d():

FFT Conv Ele GPU Time: 4.759008884429932
FFT Conv Pruned GPU Time: 5.33543848991394
Functional Conv GPU Time: 0.07413554191589355

Using the same sizes for the CPU tests:

FFT Conv CPU Time: 66.0956494808197
Functional Conv CPU Time: 3.2627475261688232

I’ve tried doing the operation as a pruned FFT (http://www.fftw.org/pruned.html), but this made things worse.

A recent paper (2017), http://ispass.org/ispass2017/slides/kim_cnn_gpu.pdf, shows from slide 36 onward that FFT convolution should be faster than direct convolution on GPUs.

This leaves me with a few questions and requests:

Is there something I am doing wrong that causes this massive speed difference?

Is it possible to turn the element-wise multiplication + sum over input channels into a matmul operation? (I have rough code for this, but it gives incorrect results and is still slower than the direct conv; a rough sketch of the idea is below, after the questions.)

How is convolution (technically cross-correlation; see the thread "Functional Conv2d Produces Different Results VS Scipy Convolved2d") implemented in PyTorch? Winograd or direct?

How is FFT implemented in PyTorch? Does it use cuFFT as a backend? If it is cuFFT, is it possible that there is a lot of overhead from creating plans (https://docs.nvidia.com/cuda/cufft/index.html) on each call to fft?

I would appreciate any tips to speed up the code.
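
For question 2, the rough idea I have in mind is the sketch below, as a possible replacement for dft_conv (assuming the same (BS, inC, 1, H, W) image and (1, inC, outC, H, W) filter layout as in the code that follows; not verified to be correct or faster):

def dft_conv_matmul(imgR, imgIm, kernelR, kernelIm):
    # Move the frequency bins into matmul's batch dimensions:
    # images  (BS, inC, 1, H, W)   -> (H, W, BS, inC)
    # filters (1, inC, outC, H, W) -> (H, W, inC, outC)
    iR = imgR.squeeze(2).permute(2, 3, 0, 1)
    iIm = imgIm.squeeze(2).permute(2, 3, 0, 1)
    kR = kernelR.squeeze(0).permute(2, 3, 0, 1)
    kIm = kernelIm.squeeze(0).permute(2, 3, 0, 1)

    # Complex multiply per frequency bin; the matmul also does the sum over inC
    outR = torch.matmul(iR, kR) - torch.matmul(iIm, kIm)
    outIm = torch.matmul(iR, kIm) + torch.matmul(iIm, kR)

    # Back to (BS, outC, H, W)
    return outR.permute(2, 3, 0, 1), outIm.permute(2, 3, 0, 1)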

Thanks

import torch
import numpy as np
import torch.nn.functional as F
from scipy import signal
import scipy
import time

#######################################################

def dft_conv(imgR,imgIm,kernelR,kernelIm):

    # Fast complex multiplication using only three real multiplications:
    # (a+bi)(c+di) = (ac - bd) + ((a+b)(c+d) - ac - bd)i
    ac = torch.mul(kernelR, imgR)
    bd = torch.mul(kernelIm, imgIm)

    ab_cd = torch.mul(torch.add(kernelR, kernelIm), torch.add(imgR, imgIm))

    imgsR = ac - bd
    imgsIm = ab_cd - ac - bd

    # Sum over in channels
    imgsR = imgsR.sum(1)
    imgsIm = imgsIm.sum(1)

    return imgsR,imgsIm

def prepForTorch_FromNumpy(img):

    # Stack real and imaginary parts, then move them into a trailing dim of size 2
    img = np.expand_dims(img, 0)
    img = np.vstack((img, np.imag(img)))
    img = np.transpose(img, (1, 2, 0))

    # Add singleton (BS, InC, OutC) dimensions -> (1, 1, 1, H, W, 2)
    img = np.expand_dims(img, 0)
    img = np.expand_dims(img, 0)
    img = np.expand_dims(img, 0)

    return img

class FFT_Conv_Layer(torch.nn.Module):

    def __init__(self,filts,imgSize,filtSize=3,cuda=False):

        super(FFT_Conv_Layer, self).__init__()

        if cuda:
            self.filts = torch.from_numpy(filts).type(torch.float32).cuda()
        else:
            self.filts = torch.from_numpy(filts).type(torch.float32)

        self.imgSize = imgSize
        self.filtSize = filtSize

    def forward(self,imgs):

        # Pad and transform the image
        # Pad arg = (last dim pad left side, last dim pad right side, 2nd last dim left side, etc..)
        imgs = F.pad(imgs, (0, 0, 0, self.filtSize - 1, 0,self.filtSize - 1))

        imgs = torch.fft(imgs,2)

        # Extract the real and imaginary parts
        imgsR = imgs[:, :, :, :, :, 0]
        imgsIm = imgs[:, :, :, :, :, 1]

        # Pad and transform the filters
        filts = F.pad(self.filts, (0, 0, 0, self.imgSize - 1, 0, self.imgSize - 1))

        filts = torch.fft(filts, 2)

        # Extract the real and imaginary parts
        filtR = filts[:, :, :, :, :, 0]
        filtIm = filts[:, :, :, :, :, 1]

        # Do element wise complex multiplication
        imgsR, imgsIm = dft_conv(imgsR,imgsIm,filtR,filtIm)

        # Add dim to concat over
        imgsR = imgsR.unsqueeze(4)
        imgsIm = imgsIm.unsqueeze(4)

        # Concat the real and imaginary again then IFFT
        imgs = torch.cat((imgsR,imgsIm),-1)
        imgs = torch.ifft(imgs,2)

        # Filter and imgs were real, so the imaginary part should be ~0
        # Crop back to the 'same' output size (this 1:-1 crop assumes filtSize == 3)
        imgs = imgs[:,:,1:-1,1:-1,0]

        return imgs


class FFTpruned_Conv_Layer(torch.nn.Module):

    def __init__(self,filts,imgSize,filtSize=3,cuda=False):

        super(FFTpruned_Conv_Layer, self).__init__()

        if cuda:
            self.filts = torch.from_numpy(filts).type(torch.float32).cuda()
        else:
            self.filts = torch.from_numpy(filts).type(torch.float32)

        self.imgSize = imgSize
        self.filtSize = filtSize

    def forward(self,imgs):

        # Pad and transform the image
        # Pad arg = (last dim pad left side, last dim pad right side, 2nd last dim left side, etc..)

        imgs = F.pad(imgs, (0, 0, 0, self.filtSize - 1, 0,self.filtSize - 1))
        imgs = torch.fft(imgs,2)

        # Extract the real and imaginary parts
        imgsR = imgs[:, :, :, :, :, 0]
        imgsIm = imgs[:, :, :, :, :, 1]

        # Pad and transform the filters
        # Pruned version: a 2D FFT is an FFT over rows followed by an FFT over columns,
        # so in the first pass only the first filtSize rows need transforming
        # (the remaining padded rows are zeros); the second pass then transforms all columns
        filts = F.pad(self.filts, (0, 0, 0, self.imgSize - 1, 0, 0))
        filts = torch.fft(filts,1)

        filts = torch.transpose(filts,3,4)
        filts = F.pad(filts,(0,0,0,self.imgSize - 1,0,0))
        filts = torch.fft(filts,1)
        filts = torch.transpose(filts,3,4)

        # Extract the real and imaginary parts
        filtR = filts[:, :, :, :, :, 0]
        filtIm = filts[:, :, :, :, :, 1]

        # Do element wise complex multiplication
        imgsR, imgsIm = dft_conv(imgsR,imgsIm,filtR,filtIm)

        # Add dim to concat over
        imgsR = imgsR.unsqueeze(4)
        imgsIm = imgsIm.unsqueeze(4)

        # Concat the real and imaginary again then IFFT
        imgs = torch.cat((imgsR,imgsIm),-1)
        imgs = torch.ifft(imgs,2)

        # Filter and imgs were real, so the imaginary part should be ~0
        # Crop back to the 'same' output size (this 1:-1 crop assumes filtSize == 3)
        imgs = imgs[:,:,1:-1,1:-1,0]

        return imgs


def initialTest():
    imgSize = 5
    inCs = 1
    outCs = 1

    testImg = np.array([[1.0,2,3,4,5],[4,5,6,7,8],[7,8,9,10,11],[11,12,13,14,15],[16,17,18,19,20]])
    testFilt = np.array([[1,2,5],[3.0,4,2],[7,8,9]])

    # Numpy test
    npConv = scipy.signal.convolve2d(testImg,testFilt,mode='same')

    # Make arrays into the expected torch layout (BS, InC, OutC, ImgH, ImgW, 2 -> real | imag)
    img = prepForTorch_FromNumpy(testImg)
    filt = prepForTorch_FromNumpy(testFilt)

    img = torch.from_numpy(img).type(torch.float32)

    fftConv = FFT_Conv_Layer(filt,imgSize)
    fftOut = fftConv(img)

    fftPruned = FFTpruned_Conv_Layer(filt,imgSize)
    fftP_Out = fftPruned(img)

    # Only need real part for conv2d
    img = img[:,:,0,:,:,0]
    filt = filt[:,:,0,:,:,0]

    filt = torch.from_numpy(filt).type(torch.float32)

    # Padding pads on both sides symmetrically
    # Doesn't match scipy: F.conv2d does cross-correlation, NOT convolution
    funOut = F.conv2d(img, filt,bias=None,padding=1,stride=(1,1))

    print(npConv)
    print(fftOut)
    print(fftP_Out)
    print(funOut)

def largerTestCPU():

    filtSize = 3
    inCs = 3
    outCs = 32
    batchSize = 100
    imgSize = 16
    imagDim = 2

    imgs = torch.randn(batchSize,inCs,1,imgSize, imgSize,imagDim)
    filts = np.random.normal(size=(1,inCs,outCs,filtSize,filtSize,imagDim))

    fftConv = FFT_Conv_Layer(filts, imgSize)

    st = time.time()
    for i in range(50):
        fftOut = fftConv(imgs)
    et = time.time()
    print("FFT Conv CPU Time: {}".format(et - st))

    filts = torch.from_numpy(filts).type(torch.float32)
    filts = torch.transpose(filts,1,2)

    imgs = imgs.squeeze(2)
    filts = filts.squeeze(0)
    imgs = imgs[:,:,:,:,0]
    filts = filts[:,:,:,:,0]

    st = time.time()
    for i in range(50):
        funOut = F.conv2d(imgs, filts, bias=None, padding=1)
    et = time.time()
    print("Functional Conv CPU Time: {}".format(et - st))

def largerTestGPU():

    filtSize = 3
    inCs = 16
    outCs = 32
    batchSize = 100
    imgSize = 64
    imagDim = 2
    numIters = 50

    imgs = torch.randn(batchSize,inCs,1,imgSize, imgSize,imagDim).cuda()
    filts = np.random.normal(size=(1,inCs,outCs,filtSize,filtSize,imagDim))

    fftConv = FFT_Conv_Layer(filts, imgSize,cuda=True)

    # GPU warm up time
    for i in range(2):
        fftOut = fftConv(imgs)

    # Element wise
    torch.cuda.synchronize()
    st = time.time()
    for i in range(numIters):
        fftOut = fftConv(imgs)
    torch.cuda.synchronize()
    et = time.time()
    print("FFT Conv Ele GPU Time: {}".format(et - st))

    fftPruned = FFTpruned_Conv_Layer(filts, imgSize, cuda=True)

    # Pruned FFT
    torch.cuda.synchronize()
    st = time.time()
    for i in range(numIters):
        fftOut = fftPruned(imgs)
    torch.cuda.synchronize()
    et = time.time()
    print("FFT Conv Pruned GPU Time: {}".format(et - st))

    filts = torch.from_numpy(filts).type(torch.float32).cuda()
    filts = torch.transpose(filts,1,2)

    imgs = imgs.squeeze(2)
    filts = filts.squeeze(0)
    imgs = imgs[:,:,:,:,0]
    filts = filts[:,:,:,:,0]

    # Functional Conv
    torch.cuda.synchronize()
    st = time.time()
    for i in range(numIters):
        funOut = F.conv2d(imgs, filts, bias=None, padding=1)
    torch.cuda.synchronize()
    et = time.time()
    print("Functional Conv GPU Time: {}".format(et - st))

# initialTest()
largerTestCPU()
largerTestGPU()

Hi,

I don't have specific knowledge on this, but the timing code you shared is not correct for CUDA timings:
you need to add a torch.cuda.synchronize() before the second time.time(), otherwise you only measure the time it takes to launch the work on the GPU, not the full execution time.
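
That is, roughly (reusing the variables from your script):

torch.cuda.synchronize()   # make sure previously queued work is done before starting the clock
st = time.time()
for i in range(numIters):
    out = fftConv(imgs)
torch.cuda.synchronize()   # wait for all the queued kernels to actually finish
et = time.time()
print("GPU time: {}".format(et - st))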

Good to know, I’ve edited the initial post. The timings are now unfortunately much worse:

FFT Conv Ele GPU Time: 4.759008884429932
FFT Conv Pruned GPU Time: 5.33543848991394
Functional Conv GPU Time: 0.07413554191589355

Using the same sizes for the CPU tests:

FFT Conv CPU Time: 66.0956494808197
Functional Conv CPU Time: 3.2627475261688232

I profiled the code using the torch.autograd.profiler.profile(use_cuda=True).
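
For reference, the profiling looked roughly like this (a minimal sketch around one forward call, using the setup from largerTestGPU()):

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    fftOut = fftConv(imgs)
print(prof)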

torch.mul and sub take up the large majority of the time. I find this odd, considering the FFT should be the most expensive operation.

mul Total Cuda Time: 13374.464us
sub Total Cuda Time: 12881.920us
fft Total Cuda Time: 377.856us

A single multiplication takes more time than the entire conv2d call. Is there any reason for this?

FFT Conv

------------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                       CPU time        CUDA time            Calls        CPU total       CUDA total
------------------  ---------------  ---------------  ---------------  ---------------  ---------------
ConstantPadNd             137.393us        912.160us                1        137.393us        912.160us
tensor                     13.192us         12.704us                1         13.192us         12.704us
fill_                      20.303us        271.840us                1         20.303us        271.840us
narrow                     10.064us          3.072us                1         10.064us          3.072us
as_strided                  4.841us          1.024us                1          4.841us          1.024us
narrow                      6.956us          3.072us                1          6.956us          3.072us
as_strided                  3.490us          1.024us                1          3.490us          1.024us
fft                       366.340us       1125.376us                1        366.340us       1125.376us
reshape                     6.930us          3.072us                1          6.930us          3.072us
as_strided                  2.757us          1.024us                1          2.757us          1.024us
_fft_with_size            341.994us       1115.136us                1        341.994us       1115.136us
reshape                     7.154us          3.072us                1          7.154us          3.072us
as_strided                  2.890us          1.024us                1          2.890us          1.024us
select                      6.358us          3.072us                1          6.358us          3.072us
as_strided                  2.261us          1.024us                1          2.261us          1.024us
select                      5.708us          3.072us                1          5.708us          3.072us
as_strided                  2.252us          1.024us                1          2.252us          1.024us
ConstantPadNd              69.969us         98.304us                1         69.969us         98.304us
tensor                      3.763us          1.024us                1          3.763us          1.024us
fill_                       8.095us         81.952us                1          8.095us         81.952us
narrow                      6.727us          3.072us                1          6.727us          3.072us
as_strided                  2.766us          1.024us                1          2.766us          1.024us
narrow                      6.890us          3.072us                1          6.890us          3.072us
as_strided                  3.516us          1.024us                1          3.516us          1.024us
fft                       327.536us        377.856us                1        327.536us        377.856us
reshape                     6.212us          3.072us                1          6.212us          3.072us
as_strided                  2.527us          1.024us                1          2.527us          1.024us
_fft_with_size            305.552us        367.616us                1        305.552us        367.616us
reshape                     6.970us          3.072us                1          6.970us          3.072us
as_strided                  2.821us          1.024us                1          2.821us          1.024us
select                      6.008us          3.072us                1          6.008us          3.072us
as_strided                  2.276us          1.024us                1          2.276us          1.024us
select                      6.037us          3.072us                1          6.037us          3.072us
as_strided                  2.459us          1.024us                1          2.459us          1.024us
expand                      3.891us          0.000us                1          3.891us          0.000us
expand                      2.262us          1.024us                1          2.262us          1.024us
mul                        13.018us      13373.440us                1         13.018us      13373.440us
expand                      2.868us          1.024us                1          2.868us          1.024us
expand                      3.245us          2.048us                1          3.245us          2.048us
mul                         7.824us      13374.464us                1          7.824us      13374.464us
add                        10.083us        130.049us                1         10.083us        130.049us
add                         7.389us        405.502us                1          7.389us        405.502us
expand                      2.945us          1.024us                1          2.945us          1.024us
expand                      2.206us          1.024us                1          2.206us          1.024us
mul                         7.154us      10663.937us                1          7.154us      10663.937us
sub                         9.580us      12881.920us                1          9.580us      12881.920us
sub                         6.989us      12941.307us                1          6.989us      12941.307us
sub                         6.710us      12944.382us                1          6.710us      12944.382us
sum                        17.342us       4509.697us                1         17.342us       4509.697us
_sum                       12.580us       4507.652us                1         12.580us       4507.652us
sum                        12.929us       4544.518us                1         12.929us       4544.518us
_sum                        7.680us       4542.458us                1          7.680us       4542.458us
unsqueeze                   4.507us          1.022us                1          4.507us          1.022us
unsqueeze                   2.375us          1.022us                1          2.375us          1.022us
cat                        15.456us       1828.865us                1         15.456us       1828.865us
ifft                      332.527us       3326.973us                1        332.527us       3326.973us
reshape                     7.441us          3.075us                1          7.441us          3.075us
as_strided                  3.722us          1.022us                1          3.722us          1.022us
_fft_with_size            308.962us       3316.734us                1        308.962us       3316.734us
reshape                     7.042us          3.075us                1          7.042us          3.075us
as_strided                  2.849us          1.030us                1          2.849us          1.030us
slice                       6.009us          3.075us                1          6.009us          3.075us
as_strided                  2.182us          1.022us                1          2.182us          1.022us
slice                       5.469us          3.075us                1          5.469us          3.075us
as_strided                  2.191us          1.030us                1          2.191us          1.030us
select                      5.729us          3.075us                1          5.729us          3.075us
as_strided                  2.198us          1.030us                1          2.198us          1.030us

F.conv2d
---------------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                          CPU time        CUDA time            Calls        CPU total       CUDA total
---------------------  ---------------  ---------------  ---------------  ---------------  ---------------
conv2d                        62.204us       1486.848us                1         62.204us       1486.848us
convolution                   58.547us       1483.776us                1         58.547us       1483.776us
_convolution                  55.039us       1482.272us                1         55.039us       1482.272us
clone                         11.581us        390.144us                1         11.581us        390.144us
tensor                         3.375us          1.024us                1          3.375us          1.024us
cudnn_convolution             29.480us       1085.440us                1         29.480us       1085.440us

The recent cuDNN library (which PyTorch is built on) already has fast conv2d algorithms in it, such as:

  1. Conv as matrix multiplication (GEMM), based on (I think) Strassen fast matrix multiplication
  2. Winograd fast conv2d
  3. FFT conv2d

So I think PyTorch itself selects the optimal algorithm for you and you do not need to do this at all.
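
One related knob (just a pointer, I have not benchmarked it here): if your input sizes are fixed, you can let cuDNN time its available algorithms on the first call and cache the fastest one:

import torch

# cuDNN will benchmark its conv algorithms (GEMM, Winograd, FFT, ...) for the
# given tensor shapes on the first forward pass and reuse the fastest one afterwards
torch.backends.cudnn.benchmark = True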

I've adapted your code and ran some tests. Your problem is that the data you are using is too small for the FFT to pay off. I need to do big convolutions, e.g. 30x30 px kernels over "full HD" 1920x1920 images. In this setup your implementation does bring a benefit, comparing your FFT convolution run on the GPU against a basic numpy FFT convolution.

kernel 30x30, img 1920x1920 (repeated timings):

'fft_conv2d_numpy_test'        52.11 ms     36.56 ms    42.81 ms    42.81 ms
'conv2d_torch_cpu_all'        429.19 ms    415.87 ms   422.51 ms   416.33 ms
'conv2d_torch_gpu_all'       1910.89 ms    111.67 ms   111.00 ms   135.16 ms
'fft_conv2d_torch_cpu_all'    282.97 ms    250.37 ms   255.89 ms   266.99 ms
'fft_conv2d_torch_gpu_all'    224.72 ms     22.27 ms    20.70 ms    21.17 ms
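
(As a rough sketch of what the numpy FFT-convolution baseline might look like, e.g. with scipy.signal.fftconvolve:)

import numpy as np
from scipy import signal

img = np.random.randn(1920, 1920)
kernel = np.random.randn(30, 30)

# FFT-based 2D convolution on the CPU, cropped to the 'same' output size
out = signal.fftconvolve(img, kernel, mode='same')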


Hello, may I know, for

   imgs = torch.randn(batchSize, inCs, 1, imgSize, imgSize, imagDim)

how imgsR and imgsIm are set from imagDim? Is there an example with a real batch of images in NCHW format?
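
From the code above it looks like index 0 of imagDim is the real part and index 1 the imaginary part, so would something like this (untested sketch) be the right way to feed a real NCHW batch?

import torch

batch = torch.randn(100, 3, 32, 32)                    # real NCHW batch (N, C, H, W)

# The layer expects (N, inC, 1, H, W, 2), where the last dim holds [real, imag]
imgs = batch.unsqueeze(2).unsqueeze(5)                 # (N, C, 1, H, W, 1)
imgs = torch.cat((imgs, torch.zeros_like(imgs)), 5)    # imaginary part set to zero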

Is there any update on this? Were you able to get conv using FFT faster than the standard torch conv?

Would you be able to share the code used for this?