Hey!
I trained a model on 2 GPUs using the DataParallel wrapper, and saved the model's state dict with:
torch.save(model.state_dict(), PATH)
For some reason, when I load the dict into a model with:
model = models.segmentation.deeplabv3_resnet101(pretrained=True)
model.load_state_dict(torch.load(PATH))
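I get a key mismatch: every key in the saved file starts with "module.", while the fresh model expects the same keys without that prefix. I realized this is caused by how I saved the model after wrapping it in DataParallel. In retrospect, I should have saved it with: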
torch.save(model.module.state_dict(), PATH)
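To see where the "module." prefix comes from, here is a minimal sketch that uses a small nn.Linear as a stand-in for the DeepLabV3 model (an assumption; any nn.Module behaves the same way). DataParallel registers the wrapped model as a submodule named module, so the wrapper's state_dict keys gain that prefix, while the inner model's keys do not:

```python
import torch.nn as nn

# Stand-in model (assumption: the real DeepLabV3 model behaves the same way).
net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

# The wrapper stores `net` as a submodule called "module",
# so every key in its state_dict is prefixed with "module.".
print(list(wrapped.state_dict().keys()))         # ['module.weight', 'module.bias']

# The inner model's state_dict has the plain keys a fresh model expects.
print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']
```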
Does anyone know a way to salvage this? That is, how can I load my state dict into a model even though I saved it differently from what the documentation recommends?
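A common way to salvage an existing file like this is to strip the leading "module." from each key before calling load_state_dict. Below is a minimal sketch in plain Python; the function name strip_module_prefix is hypothetical, and the sample keys are only illustrative:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from the start of each key.

    Keys that do not start with `prefix` are copied unchanged, so a checkpoint
    saved without DataParallel passes through as-is.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Illustrative keys; real values would be tensors.
demo = {"module.backbone.conv1.weight": 1, "module.classifier.4.bias": 2}
print(dict(strip_module_prefix(demo)))
# {'backbone.conv1.weight': 1, 'classifier.4.bias': 2}

# Hypothetical usage with the checkpoint from the question:
# state = torch.load(PATH, map_location="cpu")
# model.load_state_dict(strip_module_prefix(state))
```

Recent PyTorch releases also include torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module."), which does the same thing in place; check that your version has it before relying on it.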
As a sanity check, I also saved one model with model.module.state_dict() while using DataParallel, and another with model.state_dict() on a single GPU; in both cases the state_dict loaded with no issues at all. Help?