NaN Loss with torch.cuda.amp and CrossEntropyLoss

In my case, I think the NaN appears because the loss is too large to be represented in float16, whose largest finite value is 65504.
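
To make the range limit concrete, here is a minimal sketch in plain Python (no PyTorch or GPU needed). The helper `to_float16` is a hypothetical name I made up for illustration; it uses the standard library's half-precision `struct` format and assumes that an out-of-range value should overflow to infinity, which is what IEEE 754 hardware arithmetic does. It is not how AMP works internally, only a demonstration of how an oversized value can turn into NaN.

```python
import math
import struct

FLOAT16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_float16(x: float) -> float:
    """Round x to the nearest float16 value (hypothetical helper for illustration)."""
    try:
        # "e" is the struct format code for IEEE 754 binary16 (half precision).
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack out-of-range values; here we assume the
        # hardware behaviour of overflowing to +/-inf instead.
        return math.copysign(math.inf, x)


in_range = to_float16(1024.0)     # representable, stays finite
overflowed = to_float16(70000.0)  # above FLOAT16_MAX, becomes inf

# Once a value is inf, ordinary arithmetic such as inf - inf yields NaN.
nan_result = overflowed - overflowed

print(in_range, overflowed, nan_result)
```

If this is what is happening, one thing worth trying is casting the model outputs to float32 (for example with `.float()`) before passing them to the loss, and checking whether the NaN goes away.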