NaN in gradients when using jit.script

Dear everyone,

I’m seeing NaNs in the gradients of my loss function when I decorate it with “jit.script”.
I’ve narrowed the problem down to a sqrt:

@torch.jit.script
def myloss(...):
    ... complex operations ...
    aij2 = torch.clip(aij2, 1e-9, 1 - 1e-9)
    A_aij = torch.sqrt(1 - aij2)
    ... more operations ...

With this code, the gradient of the loss systematically contains NaNs.
If I replace the sqrt with another operation, remove @torch.jit.script, or replace it with @torch.compile, then the gradients are free of NaNs (as far as my tests go).
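One thing that may be worth checking around the sqrt step, as a minimal sketch and not a confirmed diagnosis: assuming the tensors are float32 (the post does not say), the upper clip bound 1 - 1e-9 is not representable and rounds to exactly 1.0, so 1 - aij2 can still reach 0, where the derivative of sqrt is infinite. The variable names below are hypothetical and only illustrate that numerical edge case. It does not by itself explain why eager mode and torch.compile behave differently.

```python
import torch

# Assumption: the tensors are float32 (PyTorch's default dtype); the post does not say.
upper = torch.tensor(1 - 1e-9, dtype=torch.float32)
bound_is_one = bool(upper == 1.0)  # float32 spacing just below 1 is about 6e-8

# Hypothetical input already at the upper bound: the clip leaves it at 1.0.
x = torch.tensor(1.0, dtype=torch.float32, requires_grad=True)
clipped = torch.clip(x, 1e-9, 1 - 1e-9)
arg_is_zero = bool((1 - clipped) == 0)

# d/dx sqrt(1 - x) = -1 / (2 * sqrt(1 - x)), which is infinite at x = 1.
torch.sqrt(1 - x).backward()
grad_is_inf = bool(torch.isinf(x.grad))

# With a zero upstream gradient the backward pass computes 0 / 0, giving NaN.
y = torch.tensor(1.0, dtype=torch.float32, requires_grad=True)
(torch.sqrt(1 - y) * 0.0).backward()
grad_is_nan = bool(torch.isnan(y.grad))

print(bound_is_one, arg_is_zero, grad_is_inf, grad_is_nan)
```

If this edge case is involved, a clip margin that is representable in the working dtype (for example, one larger than the dtype's machine epsilon) would keep the sqrt argument strictly positive.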
Is this NaN behavior under jit.script expected? Is it fixable?
The full operation is quite heavy (and requires the sqrt), so I was hoping jit.script would help. But maybe torch.compile is a better option? (It takes longer to start up, though…)
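For comparing the variants, a small harness along these lines might help. It is only a sketch: `sqrt_term` is a hypothetical stand-in for the sqrt step above (not the real loss), and `nan_grad_calls` is a helper name made up for this example. It calls the function several times because TorchScript may only switch to its optimized code path after some warm-up calls, so a NaN that appears only on later calls would point at the optimized path.

```python
import torch

def sqrt_term(aij2: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the sqrt step of the loss in the post.
    aij2 = torch.clip(aij2, 1e-9, 1 - 1e-9)
    return torch.sqrt(1 - aij2).sum()

def nan_grad_calls(fn, x: torch.Tensor, n_calls: int = 5):
    """Return the indices of the calls whose input gradient contains NaN."""
    bad = []
    for i in range(n_calls):
        # Fresh leaf tensor per call so gradients do not accumulate.
        leaf = x.detach().clone().requires_grad_(True)
        fn(leaf).backward()
        if torch.isnan(leaf.grad).any():
            bad.append(i)
    return bad

# Inputs well inside the clip range, checked here in eager mode only.
x = torch.linspace(0.1, 0.9, 5)
eager_bad = nan_grad_calls(sqrt_term, x)
print(eager_bad)
```

The same helper can be given `torch.jit.script(sqrt_term)` or `torch.compile(sqrt_term)` in place of `sqrt_term` to see which variant, and which call, first produces NaNs.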

Thanks for any hints!