Fourier transform and complex dtype restrictions

Hi,

Many operations do not seem to be implemented for complex dtypes yet. I am currently facing this problem:

import numpy as np
import torch
import torch.fft

x = torch.tensor(np.random.random((30, 20, 10)), requires_grad=True)
t = torch.fft.rfft(x, dim=2)    # -> (30, 20, 6), complex
t = torch.prod(t, dim=1)        # -> (30, 6); or torch.exp(torch.sum(torch.log(t), dim=1))
xn = torch.fft.irfft(t, dim=1)  # -> (30, 10)

which gives RuntimeError: prod does not support automatic differentiation for outputs with complex dtype. Using the torch.log alternative results in a similar error.

Does anyone have an idea how to get around this problem in the above situation?

Thank you in advance.

Hi,

The issue here is not with the fft but with the torch.prod operation, which does not support autograd for complex inputs at the moment.
Complex autograd is a work in progress, but you can open a feature request on GitHub if you need it now.

Thank you for your reply. I should have been clearer in formulating the problem.

I’m well aware that the problem is with prod; I only included the fft part to show the context, in case someone had an idea of how to work around the problem knowing the full picture. I have a hunch that the complex dtype could be avoided entirely. Perhaps.

Anyway, thank you. Will open the issue.

You might want to try doing the multiplication by hand, as I think regular elementwise multiplication is already supported:

# replace t = torch.prod(t, dim=1) with an explicit product of slices
dim = 1
res = t.select(dim, 0)
for i in range(1, t.size(dim)):
    res = res * t.select(dim, i)  # elementwise complex multiply supports autograd
t = res
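
For completeness, here is a minimal end-to-end sketch plugging this manual product into the original snippet (prod_along_dim is just an illustrative helper name, not a PyTorch function):

import numpy as np
import torch
import torch.fft

def prod_along_dim(t, dim):
    # multiply the slices one by one; elementwise complex multiplication supports autograd
    res = t.select(dim, 0)
    for i in range(1, t.size(dim)):
        res = res * t.select(dim, i)
    return res

x = torch.tensor(np.random.random((30, 20, 10)), requires_grad=True)
t = torch.fft.rfft(x, dim=2)     # (30, 20, 6), complex
t = prod_along_dim(t, dim=1)     # (30, 6), still differentiable
xn = torch.fft.irfft(t, dim=1)   # (30, 10), real
xn.sum().backward()              # gradients flow back to x
print(x.grad.shape)              # torch.Size([30, 20, 10])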

Great idea! I didn’t realize these would be separate implementations.

Will try tomorrow. Thank you.

Works like a charm! Thank you again.
