Failed to load a model trained with DDP for inference

I trained my model on 4 GPUs with DDP (torch.nn.parallel.DistributedDataParallel).

Now I want to load it for inference on a machine with only a single GPU, using the code below:

m = Model()
m.load_state_dict(torch.load('model.pth'))

It raises the following error:

RuntimeError: Error(s) in loading state_dict for ObjectDetector:
	Missing key(s) in state_dict: "backbone.block1.0.block.0.weight", "backbone.block1.0.block.0.bias", "backbone.block1.0.block.1.weight", "backbone.block1.0.block.1.bias", "backbone.block1.0.block.1.running_mean", "backbone.block1.0.block.1.running_var", "backbone.block1.1.block.0.weight", "backbone.block1.1.block.0.bias", "backbone.block1.1.block.1.weight", "backbone.block1.1.block.1.bias", "backbone.block1.1.block.1.running_mean", "backbone.block1.1.block.1.running_var", "backbone.block1.2.conv_block1.block.0.weight", "backbone.block1.2.conv_block1.block.0.bias", "backbone.block1.2.conv_block1.block.1.weight", "backbone.block1.2.conv_block1.block.1.bias", "backbone.block1.2.conv_block1.block.1.running_mean", "backbone.block1.2.conv_block1.block.1.running_var", "backbone.block1.2.conv_block2.block.0.weight", "backbone.block1.2.conv_block2.block.0.bias", "backbone.block1.2.conv_block2.block.1.weight", "backbone.block1.2.conv_block2.block.1.bias", "backbone.block1.2.conv_block2.block.1.running_mean", "backbone.block1.2.conv_block2.block.1.running_var", "backbone.block2.0.block.0.weight", "backbone.block2.0.block.0.bias", "backbone.block2.0.block.1.weight", "backbone.block2.0.block.1.bias", "backbone.block2.0.block.1.running_mean", "backbone.block2.0.block.1.running_var", "backbone.block2.1.conv_block1.block.0.weight", "backbone.block2.1.conv_block1.block.0.bias", "backbone.block2.1.conv_block1.block.1.weight", 

.......
"backbone.block5.1.conv_block1.block.1.running_var", 

I also tried:

m.load_state_dict(torch.load('model.pth'), strict=False)

but then the inference results are strange and not what I expected.
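The strange results follow from how strict=False works: load_state_dict skips every key that does not match instead of raising, so if none of the checkpoint keys match the model's keys, nothing is loaded and the model keeps its random initial weights. In PyTorch, load_state_dict(..., strict=False) also returns an object with missing_keys and unexpected_keys attributes that you can inspect. The sketch below reproduces that comparison in plain Python on key names only. The helper compare_keys is hypothetical, and the 'module.'-prefixed checkpoint keys are an assumption consistent with the fix posted below, not copied from the actual file.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Report which keys would be missing or unexpected when loading.

    Hypothetical helper that mirrors the "Missing key(s)" / "Unexpected key(s)"
    lists reported by load_state_dict.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # expected by the model, absent from the file
    unexpected = sorted(checkpoint_keys - model_keys)  # present in the file, unknown to the model
    return missing, unexpected


# Key names modeled on the error message; the 'module.' prefix is an assumed DDP artifact
model_keys = ["backbone.block1.0.block.0.weight", "backbone.block1.0.block.0.bias"]
ckpt_keys = ["module.backbone.block1.0.block.0.weight", "module.backbone.block1.0.block.0.bias"]

missing, unexpected = compare_keys(model_keys, ckpt_keys)
print(len(missing), len(unexpected))  # 2 2: no key matches, so strict=False loads nothing
```

With real objects you would pass m.state_dict().keys() and torch.load('model.pth').keys(); if every model key shows up as missing, the weights were never loaded.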

How should I fix this? Thanks.

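I have resolved the issue. It's related to 1686. The checkpoint was saved from the DDP-wrapped model, so every key carries a 'module.' prefix, while the plain single-GPU model expects the keys without it. Stripping the prefix before loading fixes the mismatch: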

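import torch
from collections import OrderedDict

# map_location='cpu' lets the checkpoint load regardless of which GPUs saved it
state_dict = torch.load(weight_path, map_location='cpu')

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # remove the 'module.' prefix added by DataParallel/DistributedDataParallel;
    # startswith() leaves unprefixed keys intact, whereas k[7:] would truncate them
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v

m.load_state_dict(new_state_dict)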
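The same prefix-stripping idea can be packaged as a small reusable function. This is a minimal sketch: strip_prefix is a hypothetical name, and plain integers stand in for tensors because the logic only touches the keys. Unlike the unconditional k[7:] slice, it only strips keys that actually start with the prefix, so it is safe to call on checkpoints saved with or without the DDP wrapper.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged, and the original order is preserved.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Integers stand in for tensors; only the key names matter here
sd = OrderedDict([("module.backbone.block1.0.block.0.weight", 1),
                  ("module.backbone.block1.0.block.0.bias", 2)])
print(list(strip_prefix(sd)))
# ['backbone.block1.0.block.0.weight', 'backbone.block1.0.block.0.bias']
```

In the loading code this would become m.load_state_dict(strip_prefix(torch.load(weight_path, map_location='cpu'))). Two alternatives worth checking against your PyTorch version: saving model.module.state_dict() during training avoids the prefix altogether, and newer releases provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which strips the prefix in place.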