Load a model trained on a GPU (via the C++ frontend) onto a CPU-only node (also via the C++ frontend)

I run my code on several nodes of a cluster. On GPU-enabled nodes, I train a neural network and save its parameters with torch::save. Now I want to load them on a CPU-only node, which has no GPU support. Loading fails with an error complaining that GPU support is missing (which is true). I think I need to tell torch::load to convert what it loads into a CPU model, but I can’t find anywhere how to do that.

For Python, a few posts explain the approach clearly. But how do I do this with the C++ frontend?

I also tried moving the model to the CPU with model->to(device), where device was a CPU device, before saving it on the GPU node. That didn’t help either. Any ideas on how this can be done?

The error I get while loading:

```
terminate called after throwing an instance of 'c10::Error'
what(): Cannot initialize CUDA without ATen_cuda library. PyTorch splits its backend into
```


So I ran some more tests. The problem only occurs when the model has been trained on the GPU at some point. If I use the GPU build of the PyTorch C++ library but train only on the CPU and then save the model, the CPU-only build can load it. But once a model has been trained on the GPU, saving it produces a file that the CPU-only build cannot load without triggering the error above. A workaround would help me a lot. Thanks!

Since my models are not very complex (I use a dense linear architecture), it would be easy to work around this by writing my own “save” and “load” functions that just store the weights and biases of the linear layers. I am about to do so, but it would be a waste if I am overlooking something.
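A minimal sketch of such a hand-written format in plain C++, assuming each layer's weight and bias have already been copied into `std::vector<float>` (in libtorch, for example via `.to(torch::kCPU).contiguous()` and `data_ptr<float>()`). `DenseLayer`, `save_dense` and `load_dense` are hypothetical names. The file holds only sizes and raw floats, so nothing in it refers to a device. It also assumes both nodes share the same float representation and byte order, which is typical within one cluster but not guaranteed.

```cpp
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical plain-data copy of one torch::nn::Linear layer.
// weight is row-major with shape [out_features, in_features].
struct DenseLayer {
  std::uint64_t in_features = 0;
  std::uint64_t out_features = 0;
  std::vector<float> weight;  // out_features * in_features entries
  std::vector<float> bias;    // out_features entries
};

namespace {
// Write `count` values of type T as raw bytes.
template <typename T>
void write_raw(std::ofstream& out, const T* data, std::size_t count) {
  out.write(reinterpret_cast<const char*>(data),
            static_cast<std::streamsize>(count * sizeof(T)));
}
// Read `count` values of type T from raw bytes.
template <typename T>
void read_raw(std::ifstream& in, T* data, std::size_t count) {
  in.read(reinterpret_cast<char*>(data),
          static_cast<std::streamsize>(count * sizeof(T)));
}
}  // namespace

// Layout: layer count, then per layer: in_features, out_features, weights, biases.
void save_dense(const std::string& path, const std::vector<DenseLayer>& layers) {
  std::ofstream out(path, std::ios::binary);
  if (!out) throw std::runtime_error("cannot open " + path + " for writing");
  const std::uint64_t count = layers.size();
  write_raw(out, &count, 1);
  for (const DenseLayer& l : layers) {
    // Refuse to write a layer whose data does not match its declared shape.
    if (l.weight.size() != l.in_features * l.out_features ||
        l.bias.size() != l.out_features)
      throw std::runtime_error("layer sizes do not match its data");
    write_raw(out, &l.in_features, 1);
    write_raw(out, &l.out_features, 1);
    write_raw(out, l.weight.data(), l.weight.size());
    write_raw(out, l.bias.data(), l.bias.size());
  }
  if (!out) throw std::runtime_error("failed while writing " + path);
}

// Read the layers back; every value lives in ordinary host memory.
std::vector<DenseLayer> load_dense(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  if (!in) throw std::runtime_error("cannot open " + path + " for reading");
  std::uint64_t count = 0;
  read_raw(in, &count, 1);
  if (!in) throw std::runtime_error("missing header in " + path);
  std::vector<DenseLayer> layers(static_cast<std::size_t>(count));
  for (DenseLayer& l : layers) {
    read_raw(in, &l.in_features, 1);
    read_raw(in, &l.out_features, 1);
    if (!in) throw std::runtime_error("truncated layer header in " + path);
    l.weight.resize(static_cast<std::size_t>(l.in_features * l.out_features));
    l.bias.resize(static_cast<std::size_t>(l.out_features));
    read_raw(in, l.weight.data(), l.weight.size());
    read_raw(in, l.bias.data(), l.bias.size());
  }
  if (!in) throw std::runtime_error("truncated data in " + path);
  return layers;
}
```

Raw binary is used instead of text so every float round-trips bit for bit. On the CPU node, each loaded layer could then be copied back into the matching `torch::nn::Linear` parameters.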

Can anybody confirm that this is currently an issue?

OK, somehow I missed a rather obvious solution: using the GPU build of the PyTorch library on the CPU-only node. This works and lets me load whatever I want.

I passed a CPU device to torch::jit::load as below, and it worked.

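```cpp
#include <torch/script.h>

// Deserialize the ScriptModule from a file using torch::jit::load(),
// remapping all of its tensors onto the CPU (argv[1] is the model path).
c10::Device device(c10::DeviceType::CPU);
torch::jit::script::Module module = torch::jit::load(argv[1], device);
```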