I have a loss function that uses torch.angle(), but its argument can take any complex value. So even if I use a trick like
arg + 1e-7 or something along those lines, technically nothing prevents the argument from being exactly a value that produces a
NaN during the backward pass.
What would be a clean solution to this problem? Can I just plug in torch.nan_to_num() and expect the backward pass to be happy, or should I switch to something like torch.atan(),
manually add +1e-7 to the denominator of the fraction, and replace torch.angle() with that? The
__call__ of this class is the one that is causing me issues:
Edit: I think it is solved now; the issue was denormals, which for some reason I forgot existed. I just set a threshold corresponding to the same eps that TensorFlow uses, and it seems to be fine. However, I found this:
but it mentions that it is hardware-dependent. What would be a good cross-platform way to handle this and avoid similar issues? Or is it safe to assume that it will be available on most up-to-date machines?
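The hardware-dependent option I ran into is torch.set_flush_denormal(), which the docs say is only supported on some CPUs. A portable alternative, I think, is to clamp with the smallest normal float from torch.finfo; since that is an IEEE-754 constant, it is the same number on every platform:

```python
import torch

# torch.finfo exposes IEEE-754 constants, identical on every platform;
# .tiny is the smallest *normal* float32, so anything below it is denormal
tiny = torch.finfo(torch.float32).tiny  # ~1.18e-38

mag = torch.tensor([0.0, 1e-40, 0.5])   # 1e-40 is a float32 denormal
mag = mag.clamp(min=tiny)               # flush denormals up to the normal range
```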
class ComplexCompressedMSELoss:
    def __init__(self, c_: float = 0.3, lambda_: float = 0.3, eps: float = 1e-7):
        super().__init__()
        self.c_ = c_
        self.lambda_ = lambda_
        self.eps = eps

    def __call__(self, y_pred_mask, x_complex, y_complex):
        # get target magnitude and phase
        y_mag = torch.abs(y_complex)
        y_phase = torch.angle(y_complex)

        # predicted complex stft
        y_pred_mask = y_pred_mask.squeeze(1).permute(0, 2, 1)
        y_pred_complex = y_pred_mask.type(torch.complex64) * x_complex

        # get predicted magnitude and phase
        y_pred_mag = torch.abs(y_pred_complex)
        y_pred_phase = torch.angle(y_pred_complex)

        # target complex exponential
        y_complex_exp = (y_mag ** self.c_).type(torch.complex64) * \
            torch.exp(1j * y_phase.type(torch.complex64))

        # predicted complex exponential
        y_pred_complex_exp = (y_pred_mag ** self.c_).type(torch.complex64) * \
            torch.exp(1j * y_pred_phase.type(torch.complex64))

        # magnitude only loss component
        mag_loss = torch.abs(y_mag ** self.c_ - y_pred_mag ** self.c_) ** 2
        mag_loss = torch.sum(mag_loss, dim=[1, 2])

        # complex loss component
        complex_loss = torch.abs(y_complex_exp - y_pred_complex_exp) ** 2
        complex_loss = torch.sum(complex_loss, dim=[1, 2])

        # blend both loss components
        loss = (1 - self.lambda_) * mag_loss + self.lambda_ * complex_loss

        # returns the mean blended loss of the batch
        return torch.mean(loss)
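For reference, the compression step inside __call__ can be checked in isolation: |z| ** c compresses the magnitude while exp(1j * angle(z)) keeps the phase. I believe torch.polar() builds the same compressed complex value in a single call, which would sidestep the explicit exp(1j * phase):

```python
import torch

# the compression used in __call__, in isolation:
# |z| ** c gives a compressed magnitude, exp(1j * angle) keeps the phase
z = torch.tensor([3.0 + 4.0j], dtype=torch.complex64)
c = 0.3

z_c = (z.abs() ** c).type(torch.complex64) * \
    torch.exp(1j * torch.angle(z).type(torch.complex64))

# torch.polar(abs, angle) should produce the same value in one call
z_polar = torch.polar(z.abs() ** c, torch.angle(z))
```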