PyTorch 2.x causes divergence during training with mixed precision

I was previously using PyTorch 1.13 with a standard mixed-precision setup based on autocast. Mixed precision gives a noticeable speed-up, and everything works fine on that version.
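
For reference, my training step looks roughly like the sketch below (the tiny linear model, optimizer, and random data are just placeholders so the snippet runs on its own; the real model is a transformer fed from a DataLoader):

    import torch
    from torch import nn

    # Placeholder model/optimizer/data so the sketch is self-contained.
    model = nn.Linear(512, 512).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()
    scaler = torch.amp.GradScaler("cuda")

    for step in range(100):
        inputs = torch.randn(32, 512, device="cuda")
        targets = torch.randn(32, 512, device="cuda")
        optimizer.zero_grad(set_to_none=True)

        # Forward pass under autocast (fp16 is what I use by default).
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = loss_fn(model(inputs), targets)

        # Scaled backward + optimizer step through GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()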

However, I need to upgrade PyTorch to 2.5+. When I do, my training loss starts increasing sharply around 25,000 iterations. Disabling mixed precision resolves the issue, but I need it for training speed. I tried 2.5 and 2.6; the same issue happens with both.
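
By "disabling mixed precision" I mean turning off autocast and the scaler through a single config flag, roughly like this (use_amp is my own flag name, reusing the names from the sketch above):

    # Sketch of how I toggle AMP; use_amp=False removes the divergence
    # but costs too much training speed.
    use_amp = True
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        loss = loss_fn(model(inputs), targets)

    # With enabled=False these calls are plain pass-throughs.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()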

My model contains transformers.

I tried using bf16 instead of fp16, and it started diverging even earlier (around 8,000 iterations).
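
For the bf16 runs I only changed the autocast dtype and kept the rest of the step identical (again reusing the names from the first sketch):

    # Same step, but autocast in bf16 instead of fp16.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = loss_fn(model(inputs), targets)

    # I kept GradScaler in place for the bf16 runs as well.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()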

I am using GradScaler and logged its scale factor. With fp16, the scale climbs as high as about 1 million and then drops quickly to 4096 when the divergence happens. With bf16, the scale keeps increasing even after the divergence.
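
The logging itself is just this (sketch; the interval is arbitrary):

    # GradScaler halves the scale (default backoff_factor=0.5) whenever it
    # sees inf/NaN grads, so the drop from ~1e6 to 4096 shows up clearly here.
    if step % 100 == 0:
        print(f"iter {step}: loss={loss.item():.4f}  grad scale={scaler.get_scale()}")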

I also have the following settings; I am not sure whether they could contribute to this issue:

    torch.backends.cudnn.enabled = True        # use cuDNN kernels
    torch.backends.cudnn.benchmark = True      # autotune cuDNN conv algorithms
    torch.autograd.set_detect_anomaly(False)   # anomaly detection off
    torch.autograd.profiler.profile(False)     # autograd profiler off
    torch.autograd.profiler.emit_nvtx(False)   # NVTX emission off

What might be the issue? I can provide further details if necessary. Thanks in advance.


I have tried everything at this point and am still not sure what is causing this. Any help would be appreciated.