Is the loss NaN on the very first iteration (before any weight updates)? If so, I would inspect the model's intermediate outputs and find where the first NaN originates. A utility like torch.isfinite can help with this.
If the loss doesn’t start out as NaN but becomes NaN after some iterations, I would check your optimizer’s learning rate schedule and see whether decreasing the learning rate helps.
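A cheap way to check the schedule is to log the effective learning rate each step and see whether the blow-up coincides with a particular LR. A minimal sketch (the dummy parameter and the `StepLR` choice are just placeholders for your own setup):

```python
import torch

# Dummy parameter/optimizer standing in for your real training setup.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=1e-1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

lrs = []
for step in range(30):
    optimizer.step()      # (your actual training step goes here)
    scheduler.step()
    # Record the LR the next step will actually use.
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs[0], lrs[-1])
```

Plotting (or printing) `lrs` alongside the loss curve quickly shows whether the divergence lines up with a schedule change; if it does, lowering the base LR or softening the schedule is the first thing to try.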