The title says it all. Zero-division in PyTorch returns NaNs, while mathematically it should return infinity. In certain cases, torch.inf would convert back to finite values further on (e.g. exp(-torch.inf), 1/torch.inf, etc.), while NaNs, on the other hand, propagate endlessly, messing up model gradients. PyTorch functions already yield correct results for infinite values, so all that's left is to return infinities after division. Is there any way to achieve this without resorting to masks or torch.where?
torch.nan_to_num(torch.div(numerator, denominator), nan=0.0, posinf=float('inf'), neginf=float('-inf'))
Not sure how efficient this is, though, but it seems like the ideal solution according to the docs.
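For illustration, a quick sketch (not from the original post) of what this call does on small float tensors: signed infinities from x/0 pass through untouched, and only the 0/0 NaN gets replaced by 0.

import torch

num = torch.tensor([1.0, -1.0, 0.0])
den = torch.zeros(3)
out = torch.nan_to_num(num / den, nan=0.0, posinf=float('inf'), neginf=float('-inf'))
print(out)  # tensor([inf, -inf, 0.])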
In addition to @Soumya_Kundu's recommendation, I also see that Inf is returned:
x = torch.tensor(1.)
y = torch.tensor(0.)
print(x / y)
# tensor(inf)
In my experience, if there is an infinity in the forward computation, there will be a NaN in the backward. What would you expect the gradient of exp(-1/x) to be at x = 0+ in your example?
If you have y = 1/x, it is better to make sure that x is positive and bounded away from zero, so that subsequent operations like y**2 do not overflow and the gradient stays bounded. E.g.
x = softplus(t)
y = 1/(x + eps)
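For what it's worth, a minimal sketch of both points: the exp(-1/x) example above, and the softplus + eps remedy (eps = 1e-6 and t = -20.0 are just assumed values for illustration):

import torch
import torch.nn.functional as F

# inf in the forward turns into NaN in the backward
x = torch.tensor(0.0, requires_grad=True)
y = torch.exp(-1.0 / x)   # 1/0 = inf, exp(-inf) = 0
y.backward()
print(y.item(), x.grad)   # 0.0 tensor(nan)

# keeping the denominator positive and bounded away from zero keeps everything finite
eps = 1e-6
t = torch.tensor(-20.0, requires_grad=True)
y = 1.0 / (F.softplus(t) + eps)
y.backward()
print(y.item(), t.grad)   # large but finite value, finite gradient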
@ilykuleshov Mathematically speaking, I think this is an ambiguous case:
As you mentioned, for any x > 0 we have 1/(0 + x) > 0, and in the limit x → 0 coming from the positive side it goes to +infinity.
But for any x < 0 we have 1/(0 + x) < 0, and in the limit x → 0 coming from the negative side it goes to -infinity.
So what should the value at x = 0 be, negative or positive infinity?
I would say NaN is the right behavior if the case 1/0 is not explicitly treated by the definition of the function you use.
But if positive infinity is your intended behaviour, and it is not already the default (which @ptrblck's example suggests it is for 1/0), you can write your own function that returns torch.inf where the denominator is 0 and otherwise calls the normal torch implementation of that function, or equivalently use what @Soumya_Kundu proposed.
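Something along these lines, sketched here assuming float tensors (div_or_inf is just a made-up name, not a torch function):

import torch

def div_or_inf(numerator, denominator):
    # plain division, but force +inf wherever the denominator is exactly zero
    out = torch.div(numerator, denominator)
    return torch.where(denominator == 0, torch.full_like(out, float('inf')), out)

print(div_or_inf(torch.tensor([1.0, 0.0, 2.0]), torch.tensor([0.0, 0.0, 4.0])))
# tensor([inf, inf, 0.5000])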
Thank you very much! Indeed, it seems that PyTorch handles this correctly; my bad. Turns out I had a 0/0 situation, in which case NaNs make total sense. Solved it with torch.where, similar to what @wildfly proposed.
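For completeness, a rough sketch of how the 0/0 case can be treated explicitly with torch.where (this is only an illustration, not the original code, and the replacement value 0 is an arbitrary choice):

import torch

num = torch.tensor([0.0, 1.0, 2.0])
den = torch.tensor([0.0, 0.0, 4.0])
out = torch.where((num == 0) & (den == 0), torch.zeros_like(num), num / den)
print(out)  # tensor([0., inf, 0.5000])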