FFT and complex values in loss function

I am new to PyTorch and am trying to use it to solve an underdetermined problem where I have a limited number of samples of an FFT. For now I am using the entire FFT and just a squared-error loss.

My output looks like a right-side-up image superimposed on an upside-down image:

[Figure: image and result]

I think this means I am losing the imaginary values somewhere. I verified that all the operations (subtraction, squaring, summing) are behaving as I expect in the loss calculation. What could be causing this?

Here is my code:

import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.data import shepp_logan_phantom
from skimage.transform import resize

imsize = 128
image = resize(shepp_logan_phantom(),(imsize,imsize))
fftimage = np.fft.fft2(image)

x = torch.zeros((imsize,imsize)) # guess
x.requires_grad = True
optimizer = torch.optim.Adam([x], lr=5e-4)
data = torch.from_numpy(fftimage)

losses = []

def loss(x):
    fftx = torch.fft.fft2(x)
    loss = (data - fftx).pow(2).sum()
    loss = loss.real.pow(2) + loss.imag.pow(2)
    return loss

itera = 10000
for i in range(itera):
    current_loss = loss(x)
    optimizer.zero_grad()
    current_loss.backward()
    optimizer.step()
    losses.append(current_loss.item())  # store a plain float, not a graph-attached tensor

reconstruction = x.detach().numpy()

fig1,ax1 = plt.subplots(1,1)
ax1.plot(losses)
ax1.set_yscale("log")
ax1.set_ylabel("loss")
ax1.set_xlabel("iterations")

fig,ax = plt.subplots(1,2)
ax[0].imshow(image)
ax[1].imshow(np.abs(reconstruction))

Have you solved this problem? I have recently been working on MRI reconstruction, and I am also having problems using complex numbers in my loss function. Looking forward to hearing from you :grinning:


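Is this problem because you used .pow()? On a complex tensor, .pow(2) computes z*z rather than the squared magnitude |z|^2, so the per-frequency errors are still complex and can cancel each other when summed.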
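
To make the .pow() point concrete, here is a minimal sketch, not the original poster's code. The helper name squared_magnitude_loss, the 8x8 random target, and the learning rate are my own assumptions for illustration. It shows that squaring a complex residual lets errors cancel, while summing real^2 + imag^2 per element gives a real, non-negative loss. For a real-valued residual, the sum of its squared FFT coefficients works out to a correlation of the residual with its point-reflected copy, which may be why a flipped image shows up in the reconstruction.

```python
import torch

def squared_magnitude_loss(pred_fft, target_fft):
    # Hypothetical helper: sum of |target - pred|^2 over all frequencies.
    # Uses real^2 + imag^2 per element instead of squaring the complex value.
    diff = target_fft - pred_fft
    return (diff.real.pow(2) + diff.imag.pow(2)).sum()

# Two nonzero complex errors whose squares cancel when summed.
z = torch.tensor([1 + 1j, 1 - 1j])
naive = z.pow(2).sum()                               # (2j) + (-2j)
magnitude = (z.real.pow(2) + z.imag.pow(2)).sum()    # 2 + 2

# Small reconstruction from a full FFT, using the magnitude-based loss.
torch.manual_seed(0)
target = torch.rand(8, 8, dtype=torch.float64)       # stand-in for the phantom
data = torch.fft.fft2(target)

x = torch.zeros(8, 8, dtype=torch.float64, requires_grad=True)
optimizer = torch.optim.Adam([x], lr=1e-1)

initial_loss = squared_magnitude_loss(torch.fft.fft2(x), data).item()
for _ in range(500):
    optimizer.zero_grad()
    current_loss = squared_magnitude_loss(torch.fft.fft2(x), data)
    current_loss.backward()
    optimizer.step()
final_loss = squared_magnitude_loss(torch.fft.fft2(x), data).item()
```

In the code from the question, the equivalent change would be to compute the squared magnitude of (data - fftx) element by element before calling .sum(), rather than taking the magnitude of the already-summed complex value.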