I am trying to convolve several 1D signals via FFT convolution. Does PyTorch offer a way to avoid the for loop below when computing many 1D FFTs / iFFTs at once, i.e. is there a batch-wise 1D FFT that transforms every row of a 2D tensor in a single call?
import torch
# 1D convolution (mode = full), batched over every dimension except `dim`
def fftconv1d(s1, s2, dim=-1):
    """Return the full linear convolution of ``s1`` and ``s2`` along ``dim``.

    Both inputs are zero-padded to the full output length ``n1 + n2 - 1``.
    Padding to at least that length makes the FFT's circular convolution
    equal to the linear one. The padded spectra are multiplied and
    transformed back.

    ``torch.fft.fft`` / ``torch.fft.ifft`` transform along one dimension and
    treat all other dimensions as a batch. A whole stack of signals is
    therefore convolved in a single call, with no Python loop.

    Args:
        s1: Tensor of signals; the time axis is ``dim``.
        s2: Tensor of signals; the time axis is ``dim``. Its length along
            ``dim`` may differ from ``s1``'s. The other dimensions must
            broadcast against ``s1``'s, because the spectra are multiplied
            elementwise.
        dim: Dimension holding the time axis. Defaults to the last one;
            for 1D inputs this is the only dimension.

    Returns:
        Real tensor with ``s1.shape[dim] + s2.shape[dim] - 1`` samples along
        ``dim``. The imaginary part of the inverse FFT is discarded.
    """
    # Lengths are taken along `dim`, not via len() (which reads dim 0).
    # This lets a (batch, time) input work, and using both lengths means
    # unequal inputs are not truncated by the FFT.
    n_out = s1.shape[dim] + s2.shape[dim] - 1
    sp1 = torch.fft.fft(s1, n=n_out, dim=dim)
    sp2 = torch.fft.fft(s2, n=n_out, dim=dim)
    return torch.real(torch.fft.ifft(sp1 * sp2, n=n_out, dim=dim))


A = torch.randn(20, 25)
B = torch.randn(20, 25)
# One batched call replaces the per-row loop: row i of A is convolved with
# row i of B, giving a (20, 2 * 25 - 1) result.
conv = fftconv1d(A, B)