Loading a model_state_dict after training the model using nn.DataParallel()


I trained a model on 2 GPUs using nn.DataParallel, and saved the model's state dict using:

    torch.save(model.state_dict(), PATH)

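For some reason, when I create the model and load the dict into it using:

    model = models.segmentation.deeplabv3_resnet101(pretrained=True)
    model.load_state_dict(torch.load(PATH))

I get a key mismatch error. I realized this was due to how I saved my model after using DataParallel: the wrapper stores the network as model.module, so every saved key starts with "module.". In retrospect, the save call should have been: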

    torch.save(model.module.state_dict(), PATH)
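A quick way to see the mismatch is to compare the keys of a wrapped and an unwrapped model. This is a minimal sketch that assumes a small nn.Linear as a stand-in for the segmentation model; the "module." prefix behaves the same way for any nn.Module, and the snippet should also run on a CPU-only machine.

```python
import torch.nn as nn

# Stand-in model (assumption: nn.Linear instead of deeplabv3_resnet101,
# which just has many more keys).
net = nn.Linear(4, 2)

# DataParallel registers the wrapped network as a submodule named "module",
# so every key in the wrapper's state_dict gains a "module." prefix.
wrapped = nn.DataParallel(net)

plain_keys = list(net.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())
inner_keys = list(wrapped.module.state_dict().keys())

print(plain_keys)    # ['weight', 'bias']
print(wrapped_keys)  # ['module.weight', 'module.bias']
print(inner_keys)    # ['weight', 'bias']
```

Saving wrapped.module.state_dict() therefore produces the same keys as an unwrapped model, which is why that form loads cleanly.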

Does anyone know a way to salvage this situation, i.e., how I can load my state dict into a model even though I saved it differently from what the documentation recommends?
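An already-saved checkpoint can also be salvaged without DataParallel by renaming its keys before loading. The sketch below is a hypothetical helper, strip_module_prefix, written against a plain mapping; plain numbers stand in for tensors and the key names are only illustrative.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Keys without the prefix are kept unchanged, so it is safe to call this on a
    checkpoint that was saved without DataParallel.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Illustrative checkpoint saved from a DataParallel wrapper.
saved = OrderedDict([
    ("module.backbone.conv1.weight", 1),
    ("module.classifier.4.weight", 2),
])
fixed = strip_module_prefix(saved)
print(list(fixed.keys()))  # ['backbone.conv1.weight', 'classifier.4.weight']
```

With a real checkpoint this would be used as model.load_state_dict(strip_module_prefix(torch.load(PATH, map_location="cpu"))). Recent PyTorch versions also provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module."), which performs the same renaming in place.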

As a sanity check, I also saved a model using model.module.state_dict() while training with DataParallel, and another using model.state_dict() while training on a single GPU without DataParallel; in both cases the state_dict loaded with no issues at all. Help?

Figured it out, for whoever is interested:

When training a model with DataParallel, to load the state_dict into an unwrapped model (for example, one running on the CPU), save the parameters with torch.save(model.module.state_dict(), PATH).
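The save-then-load round trip might look like this. It is a minimal sketch assuming a small nn.Linear in place of the real network and a temporary file in place of PATH; map_location="cpu" lets a checkpoint written on a GPU machine be loaded on a CPU-only one.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained network (assumption: nn.Linear instead of
# deeplabv3_resnet101, so the sketch runs quickly anywhere).
model = nn.DataParallel(nn.Linear(4, 2))

# ... training would happen here ...

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
# Save the inner module, so the keys carry no "module." prefix.
torch.save(model.module.state_dict(), path)

# Later, possibly on a CPU-only machine: build an unwrapped model and load.
cpu_model = nn.Linear(4, 2)
state_dict = torch.load(path, map_location="cpu")  # move any GPU tensors to CPU
result = cpu_model.load_state_dict(state_dict)     # strict=True by default
print(result)
```

Because load_state_dict is strict by default, it raises on any missing or unexpected key, so a successful call means the names lined up.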

If you saved your model with torch.save(model.state_dict(), PATH), one fix is to first wrap your new model in nn.DataParallel (so its keys also start with "module.") and only then load the state dict.
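The wrap-then-load workaround can be sketched as follows, again assuming nn.Linear as a stand-in and a temporary file for PATH. Wrapping only renames the keys, so this should not require multiple GPUs; afterwards the unwrapped network is available as model.module.

```python
import os
import tempfile

import torch
import torch.nn as nn

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")

# Reproduce the original situation: the checkpoint was saved from the
# DataParallel wrapper, so every key starts with "module.".
trained = nn.DataParallel(nn.Linear(4, 2))
torch.save(trained.state_dict(), path)

# Workaround: wrap a fresh model first, so its keys also start with
# "module.", then load the checkpoint into the wrapper.
model = nn.DataParallel(nn.Linear(4, 2))
model.load_state_dict(torch.load(path, map_location="cpu"))

# The plain network lives in model.module; saving model.module.state_dict()
# here would convert the checkpoint to the prefix-free form once and for all.
net = model.module
```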