Loading weights for a CPU model trained on a GPU

This is not a very complicated issue, but I am not sure what the best way is to load the weights onto the CPU when the model was trained on a GPU, so here is my solution:

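    import torch
    model = torch.load('mymodel')      # without map_location, tensors are restored to the device they were saved from (the GPU)
    self.model = model.cpu().double()  # move the parameters to the CPU and cast them to float64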

I am not sure whether this should be considered a bug. This discussion is also related: link.


There is an option in torch.load (map_location) to map the parameters from one device to another.
As pointed out in the link you sent, you can force all GPU tensors to be loaded onto the CPU, which I copy here:

    import torch
    torch.load('my_file.pt', map_location=lambda storage, loc: storage)  # return each storage unchanged, i.e. keep it on the CPU

Quick question… this only seems to work when the model is trained on one GPU. If I train my model on multiple GPUs, save it, and then try to load it on the CPU, I get this error: KeyError: 'unexpected key "module.conv1.weight" in state_dict'. Is there something different that needs to happen during saving/loading for multiple GPUs?

This is a different issue; it is covered in the thread [solved] KeyError: 'unexpected key "module.encoder.embedding.weight" in state_dict'.

Perfect, thanks for the clarification. I ended up using the conversion suggested in your linked post.
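For reference, here is a minimal sketch of that kind of conversion. It assumes the checkpoint is a plain state_dict saved from a model wrapped in nn.DataParallel, which registers the wrapped network as a submodule named module, so every key gains a "module." prefix. The helper name strip_module_prefix is hypothetical, and this is not necessarily the exact code from the linked post.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading 'module.' removed from each key.

    nn.DataParallel keeps the wrapped network as a submodule called 'module',
    so every key in its state_dict starts with 'module.'.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict


# Hypothetical example: a small network wrapped in DataParallel,
# standing in for a model trained on several GPUs.
net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)
saved = wrapped.state_dict()  # keys: 'module.weight', 'module.bias'

# A plain (unwrapped) model on the CPU accepts the stripped keys.
cpu_net = nn.Linear(4, 2)
cpu_net.load_state_dict(strip_module_prefix(saved))
```

When loading from disk, this combines with map_location, e.g. strip_module_prefix(torch.load(path, map_location='cpu')). Saving model.module.state_dict() instead of model.state_dict() at training time avoids the prefix altogether.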

Thanks for the suggestion!

Wait, I don’t get it…
So in the lambda:

    torch.load('my_file.pt', map_location=lambda storage, loc: storage)

storage and loc are supposed to be replaced by the variable I want to store the value in and the target location (CPU or GPU number) respectively, right? But how do I specify the location? Are there any ‘keywords’ to do so?
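As a sketch of how this works: storage and loc are not placeholders to fill in. They are the arguments torch.load passes to your function for every storage it deserializes. storage has already been loaded on the CPU, and loc is the location string recorded at save time (for a GPU checkpoint, something like 'cuda:0'). Returning storage unchanged therefore leaves it on the CPU. map_location also accepts a string such as 'cpu', a torch.device, or a dict that remaps locations. Below, an in-memory buffer stands in for 'my_file.pt' and holds a CPU-saved tensor, so loc is 'cpu' here; the named function keep_on_cpu is hypothetical and replaces the lambda only so it can record loc.

```python
import io

import torch

# Stand-in for 'my_file.pt': a small checkpoint saved into memory.
buffer = io.BytesIO()
torch.save({"weight": torch.arange(6.0).reshape(2, 3)}, buffer)

seen_locations = []


def keep_on_cpu(storage, loc):
    # Called once per storage: `storage` is already on the CPU,
    # `loc` is the location tag recorded when the file was saved.
    seen_locations.append(loc)
    return storage  # returning it unchanged keeps it on the CPU


buffer.seek(0)
ckpt = torch.load(buffer, map_location=keep_on_cpu)

# Equivalent, more readable ways to force everything onto the CPU:
buffer.seek(0)
ckpt_str = torch.load(buffer, map_location="cpu")
buffer.seek(0)
ckpt_dev = torch.load(buffer, map_location=torch.device("cpu"))

# A dict remaps specific locations instead, e.g.
# torch.load(path, map_location={"cuda:1": "cuda:0"})
```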


I have a related question. I have a shared model trained on the GPU, and another process needs this model for inference on the CPU. I use the following command to load the shared model:

    cpu_model.load_state_dict(gpu_model.cpu().state_dict())

However, this doesn't work: it fails with CUDA error(3), initialization error. What happened?
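One thing worth checking in that snippet: nn.Module.cpu() moves the module's parameters in place and returns the same module, so gpu_model.cpu() also moves the shared GPU model off the GPU. Below is a minimal sketch that copies the weights to the CPU without touching the original. The helper cpu_state_dict and the stand-in models are hypothetical. As an assumption not confirmed in this thread, the CUDA initialization error is commonly reported when a child process created with fork makes CUDA calls, so the copy is meant to happen in the process that owns the CUDA context.

```python
import torch
import torch.nn as nn


def cpu_state_dict(module):
    # Copy every tensor to the CPU; the original module's parameters stay where they are.
    return {name: t.detach().cpu().clone() for name, t in module.state_dict().items()}


# Hypothetical stand-ins for the thread's gpu_model / cpu_model.
device = "cuda" if torch.cuda.is_available() else "cpu"
gpu_model = nn.Linear(4, 2).to(device)
cpu_model = nn.Linear(4, 2)

# Make the CPU copy in the process that owns the CUDA context,
# then hand only CPU tensors to the inference process.
cpu_model.load_state_dict(cpu_state_dict(gpu_model))
```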


If I load the model with your hack, call model.train('True'), and then try even inference, it fails. So it does not work in all cases (fine-tuning on the CPU after training on the GPU does not work).
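Here is a minimal sketch of one pattern commonly used to fine-tune on the CPU after GPU training. It assumes the weights were saved as a state_dict, rebuilds the model on the CPU, loads it with map_location='cpu', switches to training mode with a boolean, and creates the optimizer only after the parameters are on the CPU. The make_model architecture and the random data are hypothetical, and this sketch does not diagnose the failure reported above.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical architecture; use the same class that was trained on the GPU.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))


# Stand-in for the GPU training run: save only the state_dict.
trained = make_model()
if torch.cuda.is_available():
    trained = trained.cuda()
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)

# CPU side: rebuild the model and load the weights onto the CPU.
buffer.seek(0)
model = make_model()
model.load_state_dict(torch.load(buffer, map_location="cpu"))
model.train(True)  # pass a bool, not the string 'True'

# Create the optimizer after the parameters live on the CPU.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = [p.detach().clone() for p in model.parameters()]

# One fine-tuning step on random (hypothetical) data.
x = torch.randn(32, 8)
y = torch.randn(32, 1)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Recent PyTorch versions reject a non-boolean argument to Module.train() with a ValueError, so model.train(True) (or simply model.train()) is the safer call either way.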