Good performance during evaluation, but poor performance during inference

Hi, newbie here.

I tried to implement and pre-train JPEG-LM from scratch using the Llama-2 architecture. The code can be found here. During training and evaluation the model reaches high accuracy, which I suspected might just be overfitting. However, when I run inference even on the training data, the results are poor (you can check this in the code). Did I misimplement the Llama-2 model or the inference code?
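To make the eval/inference gap concrete: during evaluation the model predicts each next token while conditioned on the *ground-truth* prefix (teacher forcing), but during generation it conditions on its *own* previous predictions, so errors compound. Below is a minimal, hypothetical sketch of both modes using a toy causal LM (not my actual Llama-2 code, and the model names here are made up for illustration) to show how the two accuracy numbers are computed differently:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-in for a causal LM such as Llama-2:
# just an embedding followed by a linear head over the vocabulary.
class TinyCausalLM(nn.Module):
    def __init__(self, vocab_size=16, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, ids):
        # Logits for every position; inputs are whatever prefix we feed in.
        return self.head(self.embed(ids))

model = TinyCausalLM()
seq = torch.randint(0, 16, (1, 10))  # one fake token sequence

# Teacher-forced evaluation: predict token t+1 given the GROUND-TRUTH prefix.
with torch.no_grad():
    logits = model(seq[:, :-1])
    tf_acc = (logits.argmax(-1) == seq[:, 1:]).float().mean()

# Autoregressive inference: predict token t+1 given the MODEL'S OWN prefix.
with torch.no_grad():
    gen = seq[:, :1]  # seed with the first ground-truth token
    for _ in range(seq.size(1) - 1):
        next_id = model(gen)[:, -1].argmax(-1, keepdim=True)
        gen = torch.cat([gen, next_id], dim=1)
    ar_acc = (gen[:, 1:] == seq[:, 1:]).float().mean()

print(f"teacher-forced acc: {tf_acc.item():.3f}, autoregressive acc: {ar_acc.item():.3f}")
```

The two numbers can diverge sharply even on training data, which is one common (non-bug) explanation for what I'm seeing, though a mismatch in the inference loop (e.g. wrong position handling or KV-cache use) would look the same from the outside.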

You can check the WandB reports here. I'm confused about what is happening.