You can replace the module keys in the state_dict as follows:
# Strip the "module." prefix that nn.DataParallel adds to every parameter name
pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)
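To make the idea concrete, here is a minimal runnable sketch of the key renaming, using a toy nn.Sequential as a stand-in for your actual model (the toy architecture is only an assumption for illustration):
import torch
import torch.nn as nn

# Toy model just for illustration; replace with your own architecture
model = nn.Sequential(nn.Linear(4, 2))
dp_model = nn.DataParallel(model)

# A DataParallel state_dict has keys like "module.0.weight"
ckpt = dp_model.state_dict()

# Strip the "module." prefix so the keys match a plain (non-DataParallel) model
ckpt = {key.replace("module.", ""): value for key, value in ckpt.items()}

plain_model = nn.Sequential(nn.Linear(4, 2))
plain_model.load_state_dict(ckpt)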
Ideally, if you use DataParallel, save the checkpoint for inference as follows:
# model.module is the underlying model, so its state_dict keys carry no "module." prefix
torch.save(model.module.state_dict(), 'model_ckpt.pt')
This can also be useful for running inference on a CPU at a later time.
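For example, a checkpoint saved from model.module.state_dict() can later be loaded straight into a plain model on a CPU-only machine. A short sketch, where MyModel is a hypothetical stand-in for your model class:
import torch

# Hypothetical model class; the checkpoint keys already have no "module." prefix
model = MyModel()
state_dict = torch.load('model_ckpt.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()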