Model performance drops to nearly 1/4 when loading a checkpoint, but works fine for "simpler" data and in-script

Hi. Sorry for the unclear title. I’m currently running a named entity recognition (NER) task with a custom dataset. Everything looks fine and the validation/test set accuracy is good when running “in-script” (I’m not sure that’s the right way to describe it, but I mean saving the best model during training and evaluating it on the test set straight after training finishes). However, whenever I load that same checkpoint and run inference on the validation or test set separately, the performance drops drastically (from around 90% F1 to maybe 20% F1).

I’ve checked whether my model saving/loading code is wrong (although it’s the same code I use in other, properly working projects), but performance is fine whenever I test on a simpler dataset containing only one label, which tells me that the saving and loading part is fine.

I’ve also double-checked that the data isn’t “weird” in the sense that label indices get mixed up between runs, that there’s no test set leakage, etc., and it all seems fine.

I’m not sure if this is really an answerable topic, but at this point I’m just hoping that anybody would be able to suggest anything that I haven’t checked yet.

Thanks! Any suggestions or tips appreciated.

Edit
For context, the way I’m saving/loading the model is torch.save(self.model.state_dict(), PATH) (model is an attribute of a larger Trainer object) and model = ModelName(args, backbone); model.load_state_dict(torch.load(PATH)). When I run evaluation, I use the evaluate method that I wrote inside the Trainer object, and it has self.model.eval() at the start of the evaluation loop.
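For clarity, here’s a stripped-down, self-contained sketch of the same save/load/eval pattern (the tiny DummyTagger class below is just a stand-in for my actual ModelName(args, backbone)):

```python
import torch
import torch.nn as nn

PATH = "best_model.pt"

class DummyTagger(nn.Module):
    """Stand-in for the real model; the real one wraps a transformer backbone."""
    def __init__(self, hidden=16, num_labels=5):
        super().__init__()
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, x):
        return self.classifier(x)

# --- saving (done inside the Trainer after the best epoch) ---
model = DummyTagger()
torch.save(model.state_dict(), PATH)

# --- loading (done in a separate inference run) ---
reloaded = DummyTagger()  # must be constructed with the same config as during training
reloaded.load_state_dict(torch.load(PATH, map_location="cpu"))
reloaded.eval()           # disable dropout etc. before evaluating
```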

Try to break the issue down into the model part and the data part. Check the model first by feeding static inputs (e.g. torch.ones) in model.eval() mode and make sure both approaches yield the same results (up to the expected differences due to limited floating point precision).
If these results match, check the data loading pipeline and try to narrow down where the difference might be coming from.
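Something along these lines, for example (the linear layer here is just a stand-in for your actual model):

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model; replace with ModelName(args, backbone).
model = nn.Linear(16, 5)
torch.save(model.state_dict(), "ckpt.pt")

reloaded = nn.Linear(16, 5)
reloaded.load_state_dict(torch.load("ckpt.pt", map_location="cpu"))

model.eval()
reloaded.eval()

# A static input takes the data pipeline out of the comparison entirely.
x = torch.ones(2, 16)
with torch.no_grad():
    out_a = model(x)
    out_b = reloaded(x)

# The outputs should match up to floating point precision.
print(torch.allclose(out_a, out_b, atol=1e-6))
print((out_a - out_b).abs().max())
```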


I managed to solve the issue, and for anyone wondering what the problem was: as @ptrblck suggested, it was in the part where the data was being loaded. Specifically, for some reason the label indices used during training differed from the ones built when loading the model for inference. I simply made a separate label2id.json file and stored the label mapping there, which solved the issue.
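In case it helps anyone, this is roughly the pattern I use now (the label set below is just an example):

```python
import json

# --- at training time: freeze the mapping and write it next to the checkpoint ---
labels = sorted({"O", "B-PER", "I-PER", "B-ORG", "I-ORG"})   # example label set
label2id = {label: idx for idx, label in enumerate(labels)}
with open("label2id.json", "w") as f:
    json.dump(label2id, f)

# --- at inference time: reuse the stored mapping instead of rebuilding it ---
with open("label2id.json") as f:
    label2id = json.load(f)
id2label = {idx: label for label, idx in label2id.items()}
```

This way the label-to-index mapping is fixed once and shared between training and inference, instead of being rebuilt (possibly in a different order) on every run.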
