# create a new OrderedDict whose keys do not contain the `module.` prefix
from collections import OrderedDict

import torch

# original checkpoint was saved from a DataParallel-wrapped model
ckpt = torch.load("pretrained_model.pth")
state_dict = ckpt['model']  # in case the checkpoint stores extra entries besides the weights
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # strip the "module.encoder." prefix added by DataParallel so the
    # remaining names match the parameter names inside model.encoder
    name = k.replace("module.encoder.", "")
    new_state_dict[name] = v
# load params; strict=False skips keys that do not match, so any
# non-encoder entries left in the dict are ignored instead of raising an error
model.encoder.load_state_dict(new_state_dict, strict=False)
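Because `strict=False` silently skips anything that does not line up, it is worth checking what actually loaded. Below is a minimal, self-contained sketch of the same idea: the tiny `Encoder`/`Net` classes are hypothetical stand-ins for the real model, and the prefix stripping uses PyTorch's built-in `consume_prefix_in_state_dict_if_present` helper (available in PyTorch 1.9+) instead of the manual loop.

import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

class Encoder(nn.Module):  # hypothetical stand-in for the real encoder
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

class Net(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()

# simulate a checkpoint produced by a DataParallel-wrapped model
saved = nn.DataParallel(Net())
state_dict = saved.state_dict()  # keys look like "module.encoder.fc.weight"

# strip the "module." prefix in place with the built-in helper
consume_prefix_in_state_dict_if_present(state_dict, "module.")

model = Net()
result = model.load_state_dict(state_dict, strict=False)
# load_state_dict returns the keys it could not match; with strict=False
# these are skipped silently, so always inspect them
print("missing keys:   ", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)

If both lists print empty, every parameter in the checkpoint found its counterpart in the model.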
Thanks to this post.