NaN values in loss after a few epochs

Finally fixed this issue. It turns out that switching from FP16 autocasting to full FP32 resolved the convergence issues.
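For anyone hitting the same thing, here is a rough sketch of why FP16 can plausibly produce NaNs where FP32 does not. It is not the training code from this thread; it only uses Python's standard `struct` module to round numbers to half precision, and the helper name `to_fp16` is made up for illustration. The assumption is that some intermediate value (a large activation or loss term, or a tiny gradient) falls outside the FP16 range.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (hypothetical helper).

    Python raises OverflowError when a finite value is too large for FP16;
    here that case is mapped to +/-inf, which is an assumption meant to
    mimic what an FP16 computation would produce on overflow.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

# The largest finite FP16 value is 65504, so a moderately large number overflows.
overflowed = to_fp16(70000.0)

# Once an inf appears, ordinary arithmetic such as inf - inf yields NaN.
nan_result = overflowed - overflowed

# Very small values (for example tiny gradients) round to zero instead.
underflowed = to_fp16(1e-8)

print(overflowed, nan_result, underflowed)
```

FP32 has a far wider exponent range, so the same values stay finite and nonzero. If mixed precision is still needed for speed, loss scaling or BF16 are the usual things to try, though I have not tested them on this problem.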
