I don’t know if this is the correct way to do it, or whether I need to modify the loss function or the backward function to make it work with complex values. I don’t know where to change the code.

Loss functions for complex numbers are still a field being studied, and there are several ways to approach them.

If you’re attempting to predict a complex value, e.g. for signal processing, a Cartesian approach would be to calculate the distance between the predicted and target points in the complex plane. This would be done via:
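A minimal sketch of such a Cartesian loss, assuming the prediction and the target each arrive as separate real and imaginary tensors (`yr`, `yi`, `labelr`, `labeli`). The name `cartesian_loss` is only illustrative, and the function returns a per-element loss that you would typically reduce with `.mean()` before calling `backward()`:

```python
import torch

def cartesian_loss(yr, yi, labelr, labeli):
    # Build complex tensors from the separate real and imaginary parts
    y = torch.complex(yr, yi)
    label = torch.complex(labelr, labeli)
    # |y - label| is the Euclidean distance between the two points
    # in the complex plane (one value per element)
    return torch.abs(y - label)
```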

Another approach that is more natural for complex values is to compute a polar loss, that is, the difference between the radii (magnitudes) plus the difference between the angles (phases). This recognizes that complex values are not simply x, y coordinates. We can obtain this via:
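A rough sketch of an unscaled polar loss, under the same assumption that inputs are separate real and imaginary tensors. The name `polar_loss_unscaled` is hypothetical, and the phase difference is taken here as the angle of `label * conj(y)`, which is one way to keep it wrapped into the range -pi to pi:

```python
import torch

def polar_loss_unscaled(yr, yi, labelr, labeli):
    y = torch.complex(yr, yi)
    label = torch.complex(labelr, labeli)
    # Radial term: absolute difference between the two magnitudes
    rad_loss = torch.abs(y.abs() - label.abs())
    # Angular term: the angle of label * conj(y) is the phase difference,
    # already wrapped, so its absolute value lies between 0 and pi
    angle_loss = torch.abs(torch.angle(label * y.conj()))
    return rad_loss + angle_loss
```

Note that the two terms live on different scales here: the radial term is unbounded, while the angular term never exceeds pi.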

You may find it beneficial to scale the radial loss into the same range as the angle loss, so that neither term dominates. This can be done with a modified version of the above:

```python
import math
import torch

def polar_loss(yr, yi, labelr, labeli, eps=1e-8):
    y = torch.complex(yr, yi)
    label = torch.complex(labelr, labeli)
    # Radial term: magnitude difference divided by the larger magnitude, normalized to a range of 0 to 1
    rad_loss = torch.abs(y.abs() - label.abs()) / (torch.max(y.abs(), label.abs()) + eps)
    # Angular term: phase difference divided by pi, normalized to a range of 0 to 1
    angle_loss = torch.abs(torch.angle(label / y)) / math.pi
    return (rad_loss + angle_loss) / 2
```