Error when loading the state_dict of a torch.nn.DataParallel model

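import torch
from torchvision.models import AlexNet  # assumption: AlexNet is torchvision's class

model_0 = AlexNet()
model_1 = torch.nn.DataParallel(AlexNet(), device_ids=[0, 1])
model_1_dict = model_1.state_dict()
model_0.load_state_dict(model_1_dict)  # raises RuntimeError: missing / unexpected keys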

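The state_dict keys of model_0 and model_1 are different.

DataParallel keeps the wrapped network in its module attribute, so every key in model_1's state_dict gets a 'module.' prefix: a key k in model_0 appears as 'module.' + k in model_1. With the default strict=True, load_state_dict then reports every model_0 key as missing and every model_1 key as unexpected.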
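A common workaround is to strip that prefix before loading. The sketch below assumes every key of the wrapped model starts with exactly 'module.'; the helper name strip_module_prefix is hypothetical, and it uses only the standard library so it works on the OrderedDict returned by state_dict(). The key names in the demo are illustrative, not taken from the post.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def strip_module_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> "OrderedDict[str, V]":
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Hypothetical helper. Keys that do not start with `prefix` are copied
    unchanged, so it is also safe on a state_dict that was never wrapped.
    """
    stripped: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Illustrative keys, shaped like those a DataParallel-wrapped model reports.
wrapped = OrderedDict([("module.features.0.weight", 1), ("module.features.0.bias", 2)])
print(list(strip_module_prefix(wrapped)))
# ['features.0.weight', 'features.0.bias']
```

With this helper the failing call becomes model_0.load_state_dict(strip_module_prefix(model_1.state_dict())). Alternatively, since model_1.module is the wrapped AlexNet itself, model_0.load_state_dict(model_1.module.state_dict()) avoids the prefix entirely, and saving model_1.module.state_dict() to disk in the first place has the same effect.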