Training Issues on 'mps': -Inf and NaN

I am training nanoGPT on a dataset of COVID-19 research papers. The input file is ~5 GB.

I can train for 200,000 epochs on the CPU, but with device='mps' training hits exceptions with -Inf and NaN values after about 20,000 epochs.

I set fused=False in the AdamW() optimizer.
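For reference, the optimizer setup looks roughly like this (a sketch; the lr/betas/weight_decay values are placeholders rather than the actual config, and model is the GPT instance built in train.py):

import torch

# Placeholder hyperparameters; fused=False explicitly disables the fused kernel.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=6e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    fused=False,
)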
Here’s the stack trace:

/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/autograd/__init__.py:204: UserWarning: Error detected in LogSoftmaxBackward0. Traceback of forward call that caused the error:
  File "/Users/davidlaxer/nanoGPT/train.py", line 341, in <module>
    logits, loss = model(X, Y)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/davidlaxer/nanoGPT/model.py", line 220, in forward
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/nn/functional.py", line 3034, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/Users/davidlaxer/nanoGPT/train.py", line 377, in <module>
    scaler.scale(loss).backward() # dbl 3/14/23
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/Users/davidlaxer/anaconda3/envs/AI-Feynman/lib/python3.10/site-packages/torch/autograd/__init__.py", line 204, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output.

Process finished with exit code 1
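To narrow down where the first non-finite value appears, something like this can be wrapped around the training step (a sketch against nanoGPT's train.py; model, X, Y, optimizer, scaler, and iter_num are the names used there):

import torch

# Report the forward op that produced the bad gradient -- this is what
# generated the LogSoftmaxBackward0 traceback above.
torch.autograd.set_detect_anomaly(True)

logits, loss = model(X, Y)

# Fail fast if the loss is already inf/nan before backward() propagates it.
if not torch.isfinite(loss):
    raise RuntimeError(f"non-finite loss at iteration {iter_num}: {loss.item()}")

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)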

I first saw -Inf in the LayerNorm layer, which led to NaNs after F.layer_norm() was called.

import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
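To see where the -Inf first enters, a forward hook can flag the first module whose input or output contains a non-finite value (a sketch, not part of nanoGPT; model is the GPT instance built in train.py):

import torch

def check_finite(module, inputs, output):
    # Raise on the first module that receives or produces inf/nan.
    for t in list(inputs) + [output]:
        if torch.is_tensor(t) and not torch.isfinite(t).all():
            raise RuntimeError(f"non-finite values around {module.__class__.__name__}")

for m in model.modules():
    m.register_forward_hook(check_finite)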