I am training an encoder-only transformer model that plays chess like a human, with two prediction heads that produce logits over the board encoder's output. This is how I am saving the checkpoint:
state = {
"step": step,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}
checkpoint_folder = f'{config_name}/{checkpoint_folder}'
pathlib.Path(checkpoint_folder).mkdir(parents=True, exist_ok=True)
filename = f"{rating}_step_{step}.pt"
torch.save(state, os.path.join(checkpoint_folder, filename))
and I load the checkpoint like this:
checkpoint = torch.load(
    CONFIG.CHECKPOINT_PATH,
    weights_only=True,
)
step = int(checkpoint['step'])
start_epoch = step // CONFIG.STEPS_PER_EPOCH + 1
state_dict = checkpoint['model_state_dict']
new_state_dict = {}
for key, value in state_dict.items():
    new_key = key.replace('_orig_mod.', '')  # remove the torch.compile '_orig_mod.' prefix
    new_key = new_key.replace('module.', '')  # remove the DataParallel/DDP 'module.' prefix
    new_state_dict[new_key] = value
model.load_state_dict(new_state_dict, strict=CONFIG.USE_STRICT)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
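For what it's worth, one extra check I can do here (just a sketch, reusing the variable names from the snippet above) is to look at what load_state_dict reports, since with strict loading turned off any mismatched keys would be skipped silently:

# Same call as above, but keeping the return value so that any keys that
# were not matched (and would therefore keep their freshly initialised
# weights) become visible.
load_result = model.load_state_dict(new_state_dict, strict=CONFIG.USE_STRICT)
print("missing keys:   ", load_result.missing_keys)
print("unexpected keys:", load_result.unexpected_keys)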
I validated the model right after loading the checkpoint and the accuracy is fine. But when I resume training with model.train(), the accuracy drops a good 5-7% on the same dataset, and it doesn't climb back up to where it was before I paused training. After letting the checkpointed model train for some 500 steps, the validation accuracy decreases from where it started, which means there is definitely something wrong with the model weights. The learning rate is also saved and loaded (a rough sketch of what I mean by that is at the bottom of this post).

I've looked at a few other issues related to this, and most people fixed it by finding a difference between their training and validation dataloaders, but nothing like that applies in my case. Could someone please help me with this? I've been stuck on it for a few days now and I'm pretty desperate for answers.
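For reference, this is roughly what I mean by "the learning rate is also saved and loaded" (sketch only; 'scheduler' is a placeholder name for my actual LR scheduler object, and the real code has more bookkeeping around it):

# At save time, alongside the model and optimizer state:
state["scheduler_state_dict"] = scheduler.state_dict()

# At load time, after optimizer.load_state_dict(...):
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])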