Missing keys & unexpected keys in state_dict when loading self trained model

You can strip the "module." prefix from the keys in the state_dict as follows:

# load the checkpoint and remove the "module." prefix that DataParallel adds
pretrained_dict = torch.load('model_ckpt.pt')
pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)

Ideally, if you train with DataParallel, save the checkpoint from the underlying module so the keys are clean for inference:

torch.save(model.module.state_dict(), 'model_ckpt.pt')

This also makes it easier to run inference on CPU later, since the checkpoint no longer assumes a DataParallel wrapper.
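Here is a minimal, self-contained sketch of both points (the `nn.Linear` model is just a placeholder): wrapping a model in DataParallel prefixes every state_dict key with "module.", and you can either save the unwrapped module's state_dict or strip the prefix afterwards.

```python
import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(4, 2))

# Keys of the wrapped model carry the "module." prefix:
print(list(model.state_dict().keys()))   # ['module.weight', 'module.bias']

# Option 1 (preferred): save the unwrapped module's state_dict,
# which has clean keys ('weight', 'bias').
clean_sd = model.module.state_dict()

# Option 2: strip the prefix from an already-saved checkpoint:
stripped = {k.replace("module.", ""): v for k, v in model.state_dict().items()}

# Either dict loads into a plain (non-DataParallel) model without
# missing or unexpected keys:
fresh = nn.Linear(4, 2)
fresh.load_state_dict(stripped)
```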

Ref: https://stackoverflow.com/a/61854807/3177661
