Slow backpropagation with my own FFT implementation

Hi, I just implemented my own FFT (or DFT? Honestly not sure…) using Conv1d. I’m using it for the FFC (Fast Fourier Convolution) blocks from the paper behind this project: GitHub - advimman/lama: 🦙 LaMa Image Inpainting, Resolution-robust Large Mask Inpainting with Fourier Convolutions, WACV 2022
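
As far as I can tell, what I’m computing is the plain DFT on the last axis, written as a dense matrix multiplication, which is what a Conv1d with kernel size equal to the window realizes:

$$X_k = \sum_{n=0}^{N-1} x_n \, e^{-2\pi i k n / N}, \qquad k = 0, \dots, N-1$$

The complex matrix $w_{kn} = e^{-2\pi i k n / N}$ is split into its real and imaginary parts, one Conv1d each, and the two are recombined with the usual complex-multiplication rule in forward() below.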

As the title suggests, backpropagation through my own FFT implementation is, for some reason, super slow. When I use torch.fft.rfftn instead, as the LaMa project does, it’s fast as expected. I will obviously keep using the torch implementation, but I really want to figure out why backpropagation is slow with mine. Note that there is nothing wrong with the forward pass; it’s as fast as you’d expect. Only the backward pass is slow, and it can take up to 20 seconds.
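
For reference, this is roughly how I measure it (a minimal timing sketch, not my actual training loop; the window size, input shape, and dummy loss are made-up, and it assumes the FFT module defined below plus a CUDA device):

import time
import torch

fft = FFT(window_size=256).cuda()
x = torch.randn(64, 256, 256, device="cuda", requires_grad=True)

torch.cuda.synchronize()
t0 = time.time()
real, imag = fft(x, None)  # forward pass only
torch.cuda.synchronize()
print("forward:", time.time() - t0)

t0 = time.time()
(real.square() + imag.square()).sum().backward()  # dummy scalar loss
torch.cuda.synchronize()
print("backward:", time.time() - t0)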

Here is the code for the FFT implementation:

import torch
import torch.nn as nn
import numpy as np


class FFT(nn.Module):
    """
    FFT on the last axis
    """

    def __init__(self, window_size, freeze_parameters=True):
        super(FFT, self).__init__()

        # DFT matrix w[k, n] = ohm^(k*n) = exp(-2*pi*i*k*n / N), plus the
        # inverse matrix, scaled by 1/N
        kn = np.arange(window_size)[:, None] * np.arange(window_size)
        ohm = np.exp(-2 * np.pi * 1j / window_size)
        iohm = np.exp(2 * np.pi * 1j / window_size)

        w = np.power(ohm, kn)
        iw = np.power(iohm, kn) / window_size

        # kernel size == window size, so each of the window_size output
        # channels computes one full dot product (one row of the DFT matrix)
        self.conv_real = nn.Conv1d(1, window_size, window_size, 1, 0, bias=False)
        self.conv_imag = nn.Conv1d(1, window_size, window_size, 1, 0, bias=False)
        self.iconv_real = nn.Conv1d(1, window_size, window_size, 1, 0, bias=False)
        self.iconv_imag = nn.Conv1d(1, window_size, window_size, 1, 0, bias=False)

        self.conv_real.weight.data = torch.Tensor(w.real[:, None, :])
        self.conv_imag.weight.data = torch.Tensor(w.imag[:, None, :])
        self.iconv_real.weight.data = torch.Tensor(iw.real[:, None, :])
        self.iconv_imag.weight.data = torch.Tensor(iw.imag[:, None, :])

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad_(False)

    def forward(self, x_r, x_i, inverse=False):
        # purely real input: pass x_i=None and the imaginary part is zeros
        if x_i is None:
            x_i = torch.zeros_like(x_r)
        conv_real = self.iconv_real if inverse else self.conv_real
        conv_imag = self.iconv_imag if inverse else self.conv_imag

        # flatten all leading dims into the batch so Conv1d sees (N, 1, window)
        shape = list(x_r.shape)
        _last = shape[-1]
        _batch = int(np.prod(shape[:-1]))

        x_r = x_r.reshape(_batch, 1, _last)
        x_i = x_i.reshape(_batch, 1, _last)

        # complex multiplication (w_r + i*w_i) * (x_r + i*x_i), one Conv1d per
        # term; each result has shape (_batch, window_size, 1)
        real = conv_real(x_r) - conv_imag(x_i)
        imag = conv_imag(x_r) + conv_real(x_i)

        real = real.reshape(shape)
        imag = imag.reshape(shape)
        return real, imag
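
To rule out correctness issues, a quick sanity check along these lines (sizes and tolerance are my own picks) can be compared against torch.fft:

fft = FFT(window_size=64)
x = torch.randn(8, 64)

real, imag = fft(x, None)
ref = torch.fft.fft(x, dim=-1)
print(torch.allclose(real, ref.real, atol=1e-3))  # expect True
print(torch.allclose(imag, ref.imag, atol=1e-3))  # expect True

# round trip: forward DFT followed by the inverse should recover x
r2, i2 = fft(*fft(x, None), inverse=True)
print(torch.allclose(r2, x, atol=1e-3))  # expect True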

The usage is exactly as the LaMa paper describes: local and global context, so this FFT module is called in every iteration of the global-to-global block. But even with very few calls (4 FFT/iFFT calls in total), the backward pass still takes half a second on my 3090.
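
For concreteness, the call pattern is roughly this (a hypothetical sketch, not the actual LaMa code; spectral_conv stands in for the learned convolution applied in the frequency domain):

def global_branch(x, fft, spectral_conv):
    # sketch of the FFC global branch: FFT -> learned mixing of the
    # spectrum -> inverse FFT, keeping the real part
    real, imag = fft(x, None)
    real, imag = spectral_conv(real, imag)
    out_real, _ = fft(real, imag, inverse=True)
    return out_real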

Thanks for the help! I’d really like to understand backpropagation better.