Convert existing model/weights from DataParallel to generic?

Sorry if this isn’t the correct category. I did try to search for my problem, but I don’t think I’m using the right search terms, which is why I’m here.

So, I trained a model using nn.DataParallel and it worked fine. I was able to run inference fine on my machine, but only on my machine. I finally figured out that I mistakenly did NOT save my model in the generic form described here: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models

Is there a way to convert my existing model to a generic form so it can be used by anyone? Hopefully in an automated way, rather than me slogging through the weights manually?

Are you still able to load the checkpoint? DataParallel just hides the actual model away under its module attribute, so you can do something like this:

import torch
from torch.nn import DataParallel

my_model = ModelClass(...)              # same architecture you trained
my_model = DataParallel(my_model, ...)  # wrap it exactly as you did for training

state_dict = torch.load(...)            # the checkpoint with 'module.'-prefixed keys
my_model.load_state_dict(state_dict)

actual_model = my_model.module                  # the plain model hiding inside DataParallel
new_state_dict = actual_model.state_dict()      # keys without the 'module.' prefix
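
From there, presumably you’d save that unwrapped state_dict so it loads into a plain model on any machine. A minimal sketch, where ModelClass and the filename are just placeholders for your own:

torch.save(actual_model.state_dict(), 'model_generic.pth')  # example filename

# later, on any machine, no DataParallel needed:
plain_model = ModelClass(...)
plain_model.load_state_dict(torch.load('model_generic.pth', map_location='cpu'))
plain_model.eval()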

Haven’t tested this, but I bet something like this would work as well:

import torch

state_dict = torch.load(...)

new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace('module.', '', 1)  # each key is prefixed with 'module.' by DataParallel
    new_state_dict[k] = v

model = SomeModel(...)
model.load_state_dict(new_state_dict)
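
One small sanity check on top of this (untested sketch): load_state_dict is strict by default, so mismatched key names will raise an error. If you’d rather inspect the differences than error out, pass strict=False and look at what comes back:

missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
print(missing)     # parameters the model expected but the dict didn't provide
print(unexpected)  # keys in the dict that the model has no parameter for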

Thanks, we tried something similar to the second method and it “works”, but we’re not sure yet whether the model is predicting accurately. Still investigating… Thanks again!
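
One thing we may try, to rule out the conversion itself as the source of any accuracy difference, is checking that every converted tensor is bit-for-bit identical to the one under the module. prefix in the original checkpoint. A rough sketch, reusing state_dict and model from above:

for k, v in model.state_dict().items():
    assert torch.equal(v.cpu(), state_dict['module.' + k].cpu()), f'mismatch in {k}'
print('all tensors match')  # if this prints, any accuracy issue is not from the key renaming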