How to achieve backpropagation for a complex-valued model?

I am trying to use a simple model to rebuild a complex-valued signal. I split the real and imaginary parts into two real tensors, and the label also has two parts: the real part is labelr and the imaginary part is labeli. Here is the model, based on https://github.com/wavefrontshaping/complexPyTorch/blob/master/complexPyTorch/complexLayers.py:

import torch.nn as nn
from complexPyTorch.complexLayers import ComplexConv2d

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.conv1 = ComplexConv2d(1, 64, 3, 1, 1)
        self.conv2 = ComplexConv2d(64, 1, 3, 1, 1)

    def forward(self, xr, xi):
        xr, xi = self.conv1(xr, xi)
        xr, xi = self.conv2(xr, xi)
        return xr, xi

And I use MSE loss:

loss1 = loss_fn(yr, labelr)
loss2 = loss_fn(yi, labeli)
loss = loss1 + loss2
loss.backward()

I don't know whether this is the correct way to do it, or whether I need to modify the loss function or the backward pass to make it work with complex values. I don't know where to change the code.

Loss functions for complex numbers are an area of active study, and there are several ways to approach the problem.

If you're attempting to predict a complex value, e.g. for signal processing, a Cartesian approach would be to calculate the distance between the two points in the complex plane. This would be done via:

d = ((x2 - x1)^2 + (y2 - y1)^2)^0.5

Or in code:

def cart_loss(yr, yi, labelr, labeli):
    # element-wise Euclidean distance between prediction and label in the complex plane
    return ((yr - labelr)**2 + (yi - labeli)**2)**0.5
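
This can be dropped straight into your training loop in place of loss_fn; autograd handles the two real-valued outputs without any change to backward(). A minimal sketch, reusing model, xr, xi, labelr, and labeli from your post (the optimizer choice and the .mean() reduction are my own additions):

import torch

net = model()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)  # optimizer choice is an assumption

optimizer.zero_grad()
yr, yi = net(xr, xi)                              # predicted real and imaginary parts
loss = cart_loss(yr, yi, labelr, labeli).mean()   # reduce to a scalar before calling backward()
loss.backward()                                   # plain autograd, nothing complex-specific required
optimizer.step()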

Another approach, which is more natural for complex values, is a polar loss: the difference between the radii plus the difference between the angles. This recognizes that complex values are not simply (x, y) coordinates. We can obtain it via:

import torch

def polar_loss(yr, yi, labelr, labeli):
    y = torch.complex(yr, yi)
    label = torch.complex(labelr, labeli)
    # difference in magnitudes
    rad_loss = torch.abs(y.abs() - label.abs())
    # angular difference, taken via the quotient so wrap-around is handled
    angle_loss = torch.abs(torch.angle(label / y))
    return rad_loss + angle_loss

You may find it beneficial to scale the radial loss so that it falls in the same range as the angle loss. This can be done with a modified version of the above:

import math
import torch

def polar_loss(yr, yi, labelr, labeli, eps=1e-8):
    y = torch.complex(yr, yi)
    label = torch.complex(labelr, labeli)
    # normalized to a range of 0 to 1
    rad_loss = torch.abs(y.abs() - label.abs()) / (torch.max(y.abs(), label.abs()) + eps)
    # normalized to a range of 0 to 1 (|angle| is at most pi)
    angle_loss = torch.abs(torch.angle(label / y)) / math.pi
    return (rad_loss + angle_loss) / 2
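
To confirm that nothing special is needed for backward() (torch.complex, abs(), and torch.angle() all support autograd with respect to the real inputs on recent PyTorch versions), here is a minimal smoke test with made-up tensor shapes:

import torch

# hypothetical smoke test, just to confirm gradients flow back to the real and imaginary parts
yr = torch.randn(4, 1, 8, 8, requires_grad=True)
yi = torch.randn(4, 1, 8, 8, requires_grad=True)
labelr = torch.randn(4, 1, 8, 8)
labeli = torch.randn(4, 1, 8, 8)

loss = polar_loss(yr, yi, labelr, labeli).mean()  # reduce to a scalar before backward()
loss.backward()
print(yr.grad is not None, yi.grad is not None)   # expect: True True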