I have a loss function that uses torch.angle(), but its argument can be any complex value. Even if I apply a trick such as arg + 1e-7,
technically nothing prevents the argument from landing exactly on a value that produces a NaN
during the backward pass.
What would be a clean solution to this problem? Can I just plug in torch.nan_to_num() and expect the backward pass to be happy, or should I replace torch.angle() with something like torch.atan() and manually add 1e-7
to the denominator of the fraction? The __call__
of the class below is the one causing me issues:
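One pattern sometimes used for this kind of problem is the "double where" trick: swap the problematic inputs for a harmless value before calling torch.angle(), then mask the result, so the backward pass never evaluates the derivative near zero. The following is only a sketch. `safe_angle` and its `eps` threshold are hypothetical names, not part of PyTorch, and returning a phase of 0 for near-zero inputs is an assumption that may not suit every loss.

```python
import math
import torch


def safe_angle(z: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Phase of z, with near-zero inputs mapped to phase 0 (hypothetical helper)."""
    # The comparison is not differentiable, so no gradient flows through it.
    is_small = torch.abs(z) < eps
    # Replace near-zero entries by 1+0j so torch.angle() sees a safe input.
    z_safe = torch.where(is_small, torch.ones_like(z), z)
    phase = torch.angle(z_safe)
    # Mask the output as well, so replaced entries contribute zero gradient.
    return torch.where(is_small, torch.zeros_like(phase), phase)


# Small check: an exact zero, a regular value, and a float32 denormal.
z = torch.tensor([0j, 1 + 1j, 1e-40 + 0j], dtype=torch.complex64, requires_grad=True)
phase = safe_angle(z)
phase.sum().backward()
grads_are_finite = bool(torch.isfinite(torch.view_as_real(z.grad)).all())
```

The mask is applied both before and after torch.angle() because masking only the output would still let the backward pass multiply a zero upstream gradient by a non-finite local derivative.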
Edit: I think it is solved now. The issue was denormals, which for some reason I had forgotten existed. I just set a threshold equal to the eps that TensorFlow uses, and it seems to be fine. However, I found this:
https://pytorch.org/docs/stable/generated/torch.set_flush_denormal.html
but it mentions that it is hardware-dependent. What would be a good cross-platform way to handle this and avoid similar issues? Or is it safe to assume that it is supported on most up-to-date machines?
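For the cross-platform question, one option is to treat torch.set_flush_denormal() as optional and rely on an explicit threshold instead, since clamping behaves the same on any hardware. A minimal sketch, assuming the magnitudes are real floating-point tensors; `clamp_magnitude` is a hypothetical helper, and defaulting to the dtype's machine epsilon is an assumption (the 1e-7 used in the class is close to the float32 value).

```python
import torch

# set_flush_denormal() returns True only if the CPU supports flushing and the
# mode was enabled, so the return value tells you whether you can rely on it.
flush_supported = torch.set_flush_denormal(True)
# Restore the default mode so this sketch has no lasting side effect.
torch.set_flush_denormal(False)


def clamp_magnitude(mag: torch.Tensor, eps=None) -> torch.Tensor:
    """Clamp magnitudes from below (hypothetical helper, hardware-independent)."""
    if eps is None:
        # Assumption: use the machine epsilon of the tensor's dtype as threshold.
        eps = torch.finfo(mag.dtype).eps
    return mag.clamp_min(eps)


mags = torch.tensor([0.0, 1e-40, 1.0])  # zero, a float32 denormal, a normal value
clamped = clamp_magnitude(mags)
```

With this approach the flush-denormal flag becomes at most a performance setting, and the numerical safety comes from the threshold alone.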
import torch


class ComplexCompressedMSELoss:
    def __init__(self,
                 c_: float = 0.3,
                 lambda_: float = 0.3,
                 eps: float = 1e-7):
        super().__init__()
        self.c_ = c_
        self.lambda_ = lambda_
        self.eps = eps

    def __call__(self, y_pred_mask, x_complex, y_complex):
        # get target magnitude and phase
        y_mag = torch.abs(y_complex)
        y_phase = torch.angle(y_complex)

        # predicted complex stft
        y_pred_mask = y_pred_mask.squeeze(1).permute(0, 2, 1)
        y_pred_complex = y_pred_mask.type(torch.complex64) * x_complex

        # get predicted magnitude and phase
        y_pred_mag = torch.abs(y_pred_complex)
        y_pred_phase = torch.angle(y_pred_complex)

        # target complex exponential
        y_complex_exp = (y_mag ** self.c_).type(torch.complex64) * \
            torch.exp(1j * y_phase.type(torch.complex64))

        # predicted complex exponential
        y_pred_complex_exp = (y_pred_mag ** self.c_).type(torch.complex64) * \
            torch.exp(1j * y_pred_phase.type(torch.complex64))

        # magnitude only loss component
        mag_loss = torch.abs(y_mag ** self.c_ - y_pred_mag ** self.c_) ** 2
        mag_loss = torch.sum(mag_loss, dim=[1, 2])

        # complex loss component
        complex_loss = torch.abs(y_complex_exp - y_pred_complex_exp) ** 2
        complex_loss = torch.sum(complex_loss, dim=[1, 2])

        # blend both loss components
        loss = (1 - self.lambda_) * mag_loss + (self.lambda_) * complex_loss

        # returns the mean blended loss of the batch
        return torch.mean(loss)
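A possible alternative is to avoid torch.angle() altogether, using the identity |y|^c * exp(j * angle(y)) = y * |y|^(c - 1), which holds for any nonzero y. The sketch below is not the class above but a swapped-in formulation under that identity. `compress_complex` is a hypothetical helper, and clamping the magnitude at `eps` is an assumption that slightly changes the values for inputs smaller than `eps`.

```python
import torch


def compress_complex(z: torch.Tensor, c: float = 0.3, eps: float = 1e-7) -> torch.Tensor:
    """Power-law compress the magnitude of z while keeping its phase (hypothetical helper)."""
    # Clamp from below so the negative exponent (c - 1) never sees a zero
    # or denormal magnitude.
    mag = torch.abs(z).clamp_min(eps)
    # z * |z|^(c-1) has magnitude |z|^c and the same phase as z.
    return z * mag ** (c - 1.0)


# Check against the angle-based form on a regular value, and check that
# gradients stay finite for an exact zero and a float32 denormal.
z = torch.tensor([3 + 4j, 0j, 1e-40 + 0j], dtype=torch.complex64, requires_grad=True)
out = compress_complex(z)
reference = torch.polar(torch.abs(z[:1]) ** 0.3, torch.angle(z[:1])).detach()

(out.real ** 2 + out.imag ** 2).sum().backward()
grads_are_finite = bool(torch.isfinite(torch.view_as_real(z.grad)).all())
```

In the class above, the two complex exponentials could then be computed as `compress_complex(y_complex, self.c_, self.eps)` and `compress_complex(y_pred_complex, self.c_, self.eps)`, which would also put the otherwise unused `eps` attribute to work.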