I was able to reproduce your issue.
The bug seems to happen on this line:
accuracy = correct_prediction.float().mean()
since the data suggests the mean should be different from 0.0:
correct_predictions tensor([ True, True, True, ..., True, False, True], device='mps:0')
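
In case it helps narrow this down, here is a minimal sketch that computes the same mean on the tensor's own device and again on the CPU, so the two values can be compared. The sample tensor and the helper name `compare_mean_across_devices` are made up for illustration, and I'm only assuming the problem is specific to the MPS backend; I haven't confirmed that.

```python
import torch


def compare_mean_across_devices(mask):
    """Return (mean on the tensor's own device, mean after copying to the CPU)."""
    on_device = mask.float().mean().item()
    on_cpu = mask.cpu().float().mean().item()
    return on_device, on_cpu


# Use MPS when it is available; otherwise fall back to the CPU so the script still runs.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Hypothetical stand-in for correct_predictions: mostly True, with one False.
correct_predictions = torch.tensor(
    [True, True, True, True, False, True], device=device
)

on_device, on_cpu = compare_mean_across_devices(correct_predictions)
print(f"mean on {device}: {on_device}")
print(f"mean on cpu: {on_cpu}")
```

If the two printed values disagree, that would point at the device-side `float().mean()` rather than at the data itself.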