Dear everyone,
I’m seeing NaN errors in the gradients of my loss function when I decorate it with @torch.jit.script.
I’ve narrowed the problem down to a sqrt:
@torch.jit.script
def myloss(...):
    # ... complex operations ...
    aij2 = torch.clip(aij2, 1e-9, 1 - 1e-9)
    A_aij = torch.sqrt(1 - aij2)
    # ... more operations ...
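For reference, here is a stripped-down standalone version of the pattern and how I check the gradient for NaNs (the input shape and the final sum() are placeholders for illustration, not my real loss):

import torch

def myloss_min(aij2):
    # same clamp + sqrt pattern as the real loss; the surrounding
    # "complex operations" are omitted, the sum() just gives a scalar
    aij2 = torch.clip(aij2, 1e-9, 1 - 1e-9)
    A_aij = torch.sqrt(1 - aij2)
    return A_aij.sum()

myloss_scripted = torch.jit.script(myloss_min)

x = torch.rand(1000, requires_grad=True)
myloss_scripted(x).backward()
print(torch.isnan(x.grad).any())  # how I check the gradient for NaNs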
The gradient of the loss then systematically contains NaNs.
If I replace the sqrt with another operation, remove @torch.jit.script, or replace it with @torch.compile, then the gradients are free of NaNs (as far as my tests go).
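The kind of comparison I mean looks roughly like the sketch below (reusing myloss_min from above; this is not my exact test harness, and the input size is arbitrary):

def nan_in_grad(fn, n: int = 1000):
    # True if any NaN appears in the gradient of fn for a random input
    x = torch.rand(n, requires_grad=True)
    fn(x).backward()
    return bool(torch.isnan(x.grad).any())

print(nan_in_grad(myloss_min))                    # eager
print(nan_in_grad(torch.jit.script(myloss_min)))  # jit.script
print(nan_in_grad(torch.compile(myloss_min)))     # torch.compile (PyTorch 2.x)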
Is this expected? Is it fixable?
The full loss computation is quite heavy (and genuinely requires the sqrt), so I was hoping torch.jit.script would help with performance. But maybe torch.compile is the better option? (It takes longer to start up, though…)
Thanks for any hints!