NaN gradient for torch.cos() / torch.acos()

When you do backpropagation with the first, at some point you'll run into the derivative of acos(x), which is -1 / sqrt(1 - x^2). That blows up and can lead to your NaNs when x is close to 1 or -1.
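You can see the blow-up directly with a minimal sketch (the printed values are illustrative, not exact):

```python
import torch

# d/dx acos(x) = -1 / sqrt(1 - x^2): fine at 0, huge near 1, -inf at exactly 1
x = torch.tensor([0.0, 0.999999, 1.0], requires_grad=True)
torch.acos(x).sum().backward()
print(x.grad)  # roughly tensor([-1.0, -707.1, -inf])
```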

In particular, consider the following two functions: f(x) = cos(acos(x)) and g(x) = x. As functions they agree on [-1, 1], but their gradients behave very differently at x = 1 and x = -1. When one needs to backprop through g(x), life is easy: for some operation z on the output y = g(x), the chain rule gives you dz/dy * dy/dx = dz/dy.

On the other hand, with y = f(x), the backpropagation looks like:
dz/dy * dy/dx = dz/dy * (-sin(acos(x))) * (-1 / sqrt(1 - x^2))
Analytically this simplifies to dz/dy, since sin(acos(x)) = sqrt(1 - x^2), but autograd evaluates the two factors separately. If x is close to 1 or -1, the second factor blows up, and at x = 1 or -1 exactly you get 0 * inf = NaN.
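Here's a minimal repro of the difference between the two (the printed outputs are what I'd expect, assumed illustrative):

```python
import torch

x = torch.tensor([0.5, 1.0], requires_grad=True)

# g(x) = x: the gradient is exactly 1 everywhere
x.sum().backward()
print(x.grad)  # tensor([1., 1.])

x.grad = None  # reset before the second backward pass

# f(x) = cos(acos(x)): analytically the gradient is also 1 on (-1, 1),
# but autograd computes sin(acos(x)) and 1 / sqrt(1 - x^2) separately.
# At x = 1 that is 0 * inf, which is NaN.
torch.cos(torch.acos(x)).sum().backward()
print(x.grad)  # tensor([1., nan])
```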
