Training accuracy drops significantly and doesn't recover when resuming from a checkpoint

I am training an encoder-only transformer model that plays chess like a human, with two prediction heads that produce logits over the board encoder's output. This is how I am saving the checkpoint:

import os
import pathlib

import torch

state = {
    "step": step,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
checkpoint_folder = f'{config_name}/{checkpoint_folder}'
pathlib.Path(checkpoint_folder).mkdir(parents=True, exist_ok=True)
filename = f"{rating}_step_{step}.pt"  # f-string formatting works whether step is an int or a str
torch.save(state, os.path.join(checkpoint_folder, filename))
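
The learning-rate schedule I mention further down is saved as well, roughly along these lines (scheduler is just an illustrative name for the LR scheduler instance; this is a sketch rather than my exact code):

# Sketch only: the scheduler state can live in the same checkpoint dict...
state["scheduler_state_dict"] = scheduler.state_dict()
# ...and be restored when resuming:
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])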

and I load the checkpoint like this:

checkpoint = torch.load(
    CONFIG.CHECKPOINT_PATH,
    weights_only=True,
)

step = int(checkpoint['step'])
start_epoch = step // CONFIG.STEPS_PER_EPOCH + 1

state_dict = checkpoint['model_state_dict']
new_state_dict = {}
for key, value in state_dict.items():
    new_key = key.replace('_orig_mod.', '')  # strip the torch.compile prefix
    new_key = new_key.replace('module.', '')  # strip the DDP prefix
    new_state_dict[new_key] = value
model.load_state_dict(new_state_dict, strict=CONFIG.USE_STRICT)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
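
As a sanity check, the restored weights can be compared against the file on disk; this is only a sketch (it reuses the names above but isn't part of my training script):

checkpoint = torch.load(CONFIG.CHECKPOINT_PATH, weights_only=True)
saved = {
    key.replace('_orig_mod.', '').replace('module.', ''): value
    for key, value in checkpoint['model_state_dict'].items()
}
for name, tensor in model.state_dict().items():
    # torch.equal does an exact element-wise comparison
    if not torch.equal(tensor.cpu(), saved[name].cpu()):
        print(f'mismatch in {name}')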

I tried validating the model right after loading the checkpoint, and the accuracy is fine. But when I resume training with model.train(), the training accuracy drops a good 5-7% on the same dataset, and it never climbs back to where it was before training was paused. After letting the checkpointed model train for about 500 steps, the validation accuracy also decreases from where it started, which makes me think something is definitely wrong with the restored weights. The learning rate is also saved and loaded. I've looked at a few other threads about this, and most people fixed it by finding a difference between their training and validation dataloaders, but nothing like that applies in my case. Could someone please help me with this? I've been stuck on it for a few days now and I'm pretty desperate for answers.

Could you explain why you are manipulating the state_dict before loading it?
Also, do you load it with strict=True or not? While strict=False has valid use cases, you should not use it to silence real errors raised by mismatching keys, etc.

In the code I sent above, strict=True, and I am manipulating the state dict because the checkpoint was saved from a model that was run with torch.compile, and earlier I was also testing it with DDP, so I was stripping the prefixes that PyTorch prepends to each key. That shouldn't make a difference to the state dict, right?
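
As far as I understand, those prefixes only rename the keys; the tensors themselves are untouched. A quick sketch of the kind of check I have in mind (illustrative, not from my actual script):

compiled = torch.compile(model)
for key, value in compiled.state_dict().items():
    plain_key = key.replace('_orig_mod.', '').replace('module.', '')
    # the compiled wrapper shares its parameters with the original model,
    # so after stripping the prefix the values should match exactly
    assert torch.equal(value, model.state_dict()[plain_key])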