DataParallel changes parameter names, causing an issue with load_state_dict()

I'm working with multi-GPU, single-GPU, and CPU-only setups on the same model.
I noticed that DataParallel changes the parameter names by prepending a “module.” prefix,
e.g. from “conv1.weight” to “module.conv1.weight”.
This is fine for training, until one needs to load the parameters into a model that is not wrapped in DataParallel.
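For illustration, a minimal way to reproduce the renaming (toy model, just as an example):

    import torch
    import torch.nn as nn

    model = nn.Sequential()
    model.add_module("conv1", nn.Conv2d(3, 8, kernel_size=3))
    print(list(model.state_dict().keys()))
    # ['conv1.weight', 'conv1.bias']

    wrapped = nn.DataParallel(model)
    print(list(wrapped.state_dict().keys()))
    # ['module.conv1.weight', 'module.conv1.bias']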

Short of rebuilding the OrderedDict to split “module.” out of the names, or building a generalized trained-model loader that matches partial names and tensor shapes, is there a generalized way of handling this?
It's no biggie, but I'd rather use a generalized handler if one is available. I'm testing inference times on CPU-only devices, but my training setup is multi-GPU.

thanks

You can just modify the dict keys by hand, manually removing the “module.” prefix:

    import torch
    from collections import OrderedDict

    # load the checkpoint onto the CPU, wherever it was saved from
    state_dict = torch.load(directory, map_location=lambda storage, loc: storage)

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[len('module.'):]  # remove the `module.` prefix added by DataParallel
        else:
            name = k
        new_state_dict[name] = v

    # `model` is the un-wrapped model you want to load into
    model.load_state_dict(new_state_dict)

Right, thanks. I was asking if there was a generalized loader that didn't require that. I was hacking it like this:

from collections import OrderedDict as OD
#...
        if device == "cuda":
            self.model = torch.nn.DataParallel(self.model)
            self.model.load_state_dict(state_dict)
        else:
            # strip any leading "module." (assumes "module." doesn't appear elsewhere in the key names)
            state = OD([(key.split("module.")[-1], state_dict[key]) for key in state_dict])
            self.model.load_state_dict(state)
#...
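
For reference, a rough sketch of the kind of generalized loader I had in mind (the load_matching helper below is my own naming, not something from the library): it checks whether the target model and the checkpoint agree on the “module.” prefix and strips or adds it as needed.

    from collections import OrderedDict

    def load_matching(model, state_dict):
        """Load `state_dict` into `model`, stripping or adding the leading
        `module.` prefix so the key names line up with the model's."""
        model_has_prefix = next(iter(model.state_dict())).startswith("module.")
        ckpt_has_prefix = next(iter(state_dict)).startswith("module.")

        if ckpt_has_prefix and not model_has_prefix:
            # checkpoint saved from a DataParallel model, target model is plain
            state_dict = OrderedDict((k[len("module."):], v) for k, v in state_dict.items())
        elif model_has_prefix and not ckpt_has_prefix:
            # target model is wrapped in DataParallel, checkpoint is plain
            state_dict = OrderedDict(("module." + k, v) for k, v in state_dict.items())

        model.load_state_dict(state_dict)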