Loading weights from DataParallel models

Hi everyone,

I know that to load weights into a CPU model from a checkpoint saved while training on 1 GPU, we can use the lines below:

import torch

net = Model()  # My own architecture I define
model_path = "path/to/model.pkl"
state_dict = torch.load(model_path, map_location={"cuda:0": "cpu"})
net.load_state_dict(state_dict)

However, when I train the model on 2 GPUs, wrapping my net in DataParallel, and save it with

import torch

net = Model()
net = torch.nn.DataParallel(net)
net.cuda()

# Training
training()

# Saving
model_path = "path/to/model.pkl"
torch.save(net.state_dict(), model_path)

I then load it back for later use like this:

net = Model()
net.load_state_dict(torch.load(model_path, map_location={"cuda:0": "cpu"}))

The weights and keys in the state_dict don't map correctly. What is the solution? Should I pass a different value for map_location?
Thank you.
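For anyone landing here: the mismatch comes from how nn.DataParallel stores the network. It keeps it as a submodule named `module`, so every key in its state_dict gains a `module.` prefix, and a plain `Model` does not recognize those keys. A minimal sketch to see this, using a small `nn.Linear` as a hypothetical stand-in for `Model`:

```python
import torch.nn as nn

# Hypothetical stand-in for the poster's Model class.
model = nn.Linear(4, 2)

# DataParallel registers the wrapped network as its submodule `module`.
# Constructing it also works on a CPU-only machine; it then just calls `model`.
wrapped = nn.DataParallel(model)

print(list(model.state_dict().keys()))    # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

# Loading the prefixed keys into the plain model fails under the default strict=True.
try:
    model.load_state_dict(wrapped.state_dict())
except RuntimeError as err:
    print("RuntimeError:", str(err).splitlines()[0])
```

So the problem is the key names, not map_location; map_location only controls which device the tensors are restored to.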


For later use, if I wrap the model in DataParallel and move it to 2 GPUs, there is no problem:

import torch
from torch.nn import DataParallel

net = Model()
net = DataParallel(net)
net.cuda()
model_path = "path/to/model.pkl"
state_dict = torch.load(model_path)
net.load_state_dict(state_dict)

This works well, and I don't even need to pass a value for map_location. My concern is how to load the weights on a CPU-only or single-GPU machine. I have a temporary solution, which I don't think is the best one:

import torch
import torch.nn as nn

class WrappedModel(nn.Module):
    def __init__(self):
        super(WrappedModel, self).__init__()
        self.module = Model()  # the Model class I actually define
    def forward(self, x):
        return self.module(x)

# Then I load the weights saved by the previous code:
net = WrappedModel()
net.load_state_dict(torch.load(model_path, map_location={"cuda:0": "cpu"}))

The keys in the state_dict map perfectly and it works, but it is kind of lengthy :smiley:
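A shorter alternative to the wrapper, sketched under the assumption that the checkpoint is a plain DataParallel state_dict (keys prefixed with `module.`): load it with `map_location="cpu"` and strip the prefix from the keys before calling `load_state_dict`. `strip_module_prefix` and `make_model` are hypothetical helpers, `nn.Linear` stands in for the real model, and `io.BytesIO` stands in for the file on disk.

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def make_model():
    # Hypothetical stand-in for the poster's Model class.
    return nn.Linear(4, 2)


# Simulate a checkpoint saved from a DataParallel-wrapped net.
src = nn.DataParallel(make_model())
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)

# map_location="cpu" restores every tensor to the CPU, whichever GPU it was saved from.
state_dict = torch.load(buffer, map_location="cpu")
net = make_model()
net.load_state_dict(strip_module_prefix(state_dict))  # strict loading now succeeds
```

Passing the string "cpu" instead of a dict means you don't have to know which cuda device the tensors were saved from, so the same code should work on a CPU-only or single-GPU machine.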


I changed the code and it works, thank you.

import torch
import torch.nn as nn

class WrappedModel(nn.Module):
    def __init__(self, module):
        super(WrappedModel, self).__init__()
        self.module = module  # the model I actually define
    def forward(self, x):
        return self.module(x)

# `models`, `args` and `modelname` are defined elsewhere in my script.
model = getattr(models, args.model)(args)
model = WrappedModel(model)
state_dict = torch.load(modelname)['state_dict']
model.load_state_dict(state_dict)
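If you control the training script, another option is to save the inner network (`net.module`) rather than the DataParallel wrapper, so the checkpoint never carries the `module.` prefix and a plain model can load it anywhere. A minimal sketch, again with `nn.Linear` as a hypothetical stand-in for the real model and an in-memory buffer in place of "path/to/model.pkl":

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the poster's Model class, wrapped as during training.
net = nn.DataParallel(nn.Linear(4, 2))
if torch.cuda.is_available():
    net.cuda()

# ... training() would run here ...

# Save the *inner* module's weights, so the keys have no "module." prefix.
buffer = io.BytesIO()
torch.save(net.module.state_dict(), buffer)
buffer.seek(0)

# Later, on a CPU-only or single-GPU machine, a plain model loads it directly.
plain = nn.Linear(4, 2)
plain.load_state_dict(torch.load(buffer, map_location="cpu"))
```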