Really low accuracy of 0%

I reproduced your issue here:

The bug seems to happen in this line:
accuracy = correct_prediction.float().mean()

This is suspicious, because the data clearly suggests a mean different from 0.0 (most of the printed entries are True):
correct_predictions tensor([ True, True, True, ..., True, False, True], device='mps:0')
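One way to narrow this down is to compute the same accuracy on the CPU and on the MPS device and compare the two results. Below is a minimal sketch, assuming a 1-D boolean tensor like the one printed above. The helper names and the example values are hypothetical. The second helper counts the True entries with `.sum()` and divides by `.numel()`. It is an alternative worth trying, not a confirmed fix for the MPS behaviour.

```python
import torch


def mean_accuracy(correct_predictions: torch.Tensor) -> float:
    # Same computation as in the post: True -> 1.0, False -> 0.0, then average.
    return correct_predictions.float().mean().item()


def count_accuracy(correct_predictions: torch.Tensor) -> float:
    # Alternative: count the True entries and divide by the total number of entries.
    return correct_predictions.sum().item() / correct_predictions.numel()


# Hypothetical example data; in the post the tensor lives on 'mps:0'.
correct_predictions = torch.tensor([True, True, True, False, True])

# Reference values computed on the CPU.
print("cpu mean :", mean_accuracy(correct_predictions.cpu()))
print("cpu count:", count_accuracy(correct_predictions.cpu()))

# Repeat on MPS (if available) to see whether the results diverge.
if torch.backends.mps.is_available():
    on_mps = correct_predictions.to("mps")
    print("mps mean :", mean_accuracy(on_mps))
    print("mps count:", count_accuracy(on_mps))
```

If the CPU values look right but the MPS `mean` comes back as 0.0, that points at the device-side computation rather than the model. In that case, calling `.cpu()` on the tensor before computing the accuracy is a simple workaround.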