# FFT and complex values in loss function

I am new to PyTorch and am trying to use it to solve an underdetermined problem where I have only a limited number of samples of an FFT. For now I am using the entire FFT and just a squared-error loss.
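Because only some FFT samples are known, the data term would eventually need to be restricted to those samples. Below is a minimal sketch. It assumes the sampled frequencies are given as a boolean `mask` with the same shape as the FFT. `masked_fft_loss` is a hypothetical helper name, the random mask is for illustration only, and the residual uses the squared magnitude |data - FFT(x)|^2 of each complex entry.

```python
import torch

def masked_fft_loss(x, data, mask):
    """Hypothetical helper: sum of |data - FFT(x)|^2 over sampled frequencies only."""
    residual = data - torch.fft.fft2(x)
    # Boolean indexing keeps only the entries where mask is True (a 1-D complex tensor).
    return residual[mask].abs().pow(2).sum()

torch.manual_seed(0)
image = torch.rand(16, 16, dtype=torch.float64)
data = torch.fft.fft2(image)

# Assumed sampling pattern: a random mask keeping about a third of the frequencies.
mask = torch.rand(16, 16) < 0.33

# The true image fits the sampled data (loss is zero up to round-off) ...
print(masked_fft_loss(image, data, mask).item())

# ... and the loss cannot see any change at an unsampled frequency.
corrupted = data.clone()
corrupted[~mask] = 0
print(masked_fft_loss(image, corrupted, mask).item())
```

Changes at unsampled frequencies leave this loss unchanged, so the data term alone cannot pin those frequencies down.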

My output looks like a right-side-up image superimposed on an upside-down image, which I think means I am losing the imaginary values somewhere. I verified that all the operations in the loss calculation (subtraction, squaring, summing) behave as I expect. What could be causing this?
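The intuition about losing imaginary values can be checked directly. For a real image, conjugating the 2-D FFT corresponds to a circular 180-degree rotation of the image. Keeping only the real part of the spectrum (the average of the spectrum and its conjugate) therefore reconstructs the average of the image and its rotated copy. For a roughly left-right-symmetric phantom, that copy looks upside down. The sketch uses a small random NumPy array in place of the phantom. It illustrates one mechanism that produces this kind of twin image, not necessarily what the posted code does.

```python
import numpy as np

# Small random real "image"; any real 2-D array satisfies this identity.
rng = np.random.default_rng(0)
x = rng.random((8, 8))

F = np.fft.fft2(x)

# Throw away the imaginary part of the spectrum, then invert.
recon = np.fft.ifft2(F.real).real

# Circular 180-degree rotation: x_rot[m, n] = x[(-m) % M, (-n) % N].
x_rot = np.roll(np.flip(x), shift=(1, 1), axis=(0, 1))

# The result is the average of the image and its rotated copy.
print(np.allclose(recon, (x + x_rot) / 2))  # True
```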

Here is my code:

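```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.data import shepp_logan_phantom
from skimage.transform import resize

imsize = 128
image = resize(shepp_logan_phantom(), (imsize, imsize))
fftimage = np.fft.fft2(image)

# Initial guess; requires_grad=True so the optimizer can update it
x = torch.zeros((imsize, imsize), requires_grad=True)
data = torch.from_numpy(fftimage)  # complex128 tensor

# The optimizer was not shown in the post; Adam with lr=0.01 is an assumption
optimizer = torch.optim.Adam([x], lr=0.01)

losses = []

def loss(x):
    fftx = torch.fft.fft2(x)
    # Square each complex residual, sum, then take the squared magnitude of the sum
    loss = (data - fftx).pow(2).sum()
    loss = loss.real.pow(2) + loss.imag.pow(2)
    return loss

itera = 10000
for i in range(itera):
    optimizer.zero_grad()  # clear gradients left over from the previous iteration
    current_loss = loss(x)
    current_loss.backward()
    optimizer.step()
    losses.append(current_loss.item())  # store a float, not a graph-attached tensor

reconstruction = x.detach().numpy()

fig1, ax1 = plt.subplots(1, 1)
ax1.plot(losses)
ax1.set_yscale("log")
ax1.set_ylabel("loss")
ax1.set_xlabel("iterations")

fig, ax = plt.subplots(1, 2)  # ax is an array of two Axes objects
ax[0].imshow(image)
ax[1].imshow(np.abs(reconstruction))
plt.show()
```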

Have you solved this problem? I have recently been working on MRI reconstruction and am also having trouble using complex numbers in my loss function. Looking forward to hearing from you.
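For complex-valued data terms such as k-space consistency in MRI, one common formulation is the sum of squared magnitudes of the complex residual. Below is a minimal sketch; `complex_sse` is a hypothetical helper name. It also shows that `torch.view_as_real` turns a complex tensor into a real one with a trailing dimension of size 2. This lets real-valued losses such as `torch.nn.functional.mse_loss` handle the real and imaginary parts together. Note the factor of 2 in the mean, since each complex entry becomes two real ones.

```python
import torch
import torch.nn.functional as F

def complex_sse(pred, target):
    """Hypothetical helper: sum of squared magnitudes |pred - target|^2."""
    return (pred - target).abs().pow(2).sum()

torch.manual_seed(0)
target = torch.randn(4, 4, dtype=torch.complex64)
pred = torch.randn(4, 4, dtype=torch.complex64, requires_grad=True)

loss = complex_sse(pred, target)

# Same value when the residual is viewed as a real tensor of shape (4, 4, 2).
as_real = torch.view_as_real(pred - target).pow(2).sum()
print(torch.allclose(loss, as_real))  # True

# mse_loss on the real views averages over 2 * numel real entries.
mse_real = F.mse_loss(torch.view_as_real(pred), torch.view_as_real(target))
print(torch.allclose(mse_real, loss / (2 * pred.numel())))  # True

loss.backward()  # gradients flow back to the complex leaf tensor
print(pred.grad.dtype)  # torch.complex64
```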

Could the problem be that you used `.pow()`?
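Here is a minimal sketch of why `.pow()` matters. On a complex tensor, `.pow(2)` squares the number itself, (a + bi)^2 = a^2 - b^2 + 2abi, rather than computing the squared magnitude a^2 + b^2. Positive and negative squares can then cancel in the sum, so the loss in the question can reach zero while the residual is still nonzero. The two-element residual below is made up for illustration.

```python
import torch

residual = torch.tensor([1 + 0j, 0 + 1j])  # two nonzero complex errors

# Complex squaring: approximately [1+0j, -1+0j], which cancel when summed.
print(residual.pow(2))

# The loss from the question: square, sum, then take |sum|^2.
squared_sum = residual.pow(2).sum()
original_loss = squared_sum.real.pow(2) + squared_sum.imag.pow(2)
print(original_loss.item())  # ~0 up to float round-off, although the residual is not zero

# Squared magnitude |a + bi|^2 = a^2 + b^2 is zero only when every entry is zero.
magnitude_loss = residual.abs().pow(2).sum()
print(magnitude_loss.item())  # 2.0
```

Applied to the question's `loss` function, the suggested change would be to replace the two lines that compute `loss` with `loss = (data - fftx).abs().pow(2).sum()`. This is a suggestion; the thread does not confirm that it was tried.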