Problem loading model trained on GPU

I’m trying to load a model that was trained on a GPU onto a machine that doesn’t have CUDA available, and I get the following error:

THCudaCheck FAIL file=torch/csrc/cuda/Module.cpp line=51 error=35 : CUDA driver version is insufficient for CUDA runtime version
Traceback (most recent call last):
  File "main.py", line 240, in <module>
    resnet_model = torch.load('resnet_model.pt')
  File "/home/diego/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/home/diego/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 469, in _load
    result = unpickler.load()
  File "/home/diego/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 437, in persistent_load
    data_type(size), location)
  File "/home/diego/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 88, in default_restore_location
    result = fn(storage, location)
  File "/home/diego/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 70, in _cuda_deserialize
    return obj.cuda(device)
  File "/home/diego/anaconda3/lib/python3.6/site-packages/torch/_utils.py", line 68, in _cuda
    with torch.cuda.device(device):
  File "/home/diego/anaconda3/lib/python3.6/site-packages/torch/cuda/__init__.py", line 225, in __enter__
    self.prev_idx = torch._C._cuda_getDevice()
RuntimeError: cuda runtime error (35) : CUDA driver version is insufficient for CUDA runtime version at torch/csrc/cuda/Module.cpp:51

Is this the expected behaviour? Can’t I load a model trained on GPU into a machine with no GPU?

Try loading your state_dict or model with the map_location argument to force all tensors onto the CPU:

torch.load('my_file.pt', map_location=lambda storage, loc: storage)
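As a sketch of what map_location does: torch.load normally restores each saved storage to the device it came from (the traceback above shows it calling obj.cuda(device)), and map_location overrides that. The lambda above returns each storage unchanged, i.e. left on the CPU where it was deserialized. The string "cpu" and torch.device("cpu") are equivalent spellings. The in-memory buffer and the "weight" key below are made up for illustration and stand in for a real checkpoint file.

```python
import io

import torch

# Toy checkpoint written to an in-memory buffer. On the original GPU machine its
# tensors would have been saved from "cuda"; here they start on the CPU.
buffer = io.BytesIO()
torch.save({"weight": torch.randn(3, 3)}, buffer)

# Three spellings of "put every tensor on the CPU" accepted by torch.load.
map_locations = [
    lambda storage, loc: storage,  # the form suggested above
    "cpu",
    torch.device("cpu"),
]

loaded = []
for map_location in map_locations:
    buffer.seek(0)  # rewind so each call reads the whole checkpoint
    checkpoint = torch.load(buffer, map_location=map_location)
    loaded.append(checkpoint["weight"])
    print(checkpoint["weight"].device)  # prints: cpu
```

On a CPU-only machine the string form is usually the easiest to read. The lambda form is what older code tends to use.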

Thank you @ptrblck :smiley:

I tried to do what you advised, but I keep getting the same error.

Here’s my code:

model_transfer.load_state_dict(torch.load('model_transfer.pt'), map_location=lambda storage, loc: storage)
model_transfer.eval()
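The likely reason for the identical error is that map_location was passed to load_state_dict instead of torch.load. Python evaluates torch.load('model_transfer.pt') first, without map_location, so it still tries to restore the CUDA storages and fails before load_state_dict runs. load_state_dict has no map_location parameter anyway, so it would otherwise raise a TypeError. Below is a minimal sketch of the corrected order. build_model is a hypothetical stand-in (an nn.Linear), because the real model_transfer architecture isn't shown, and the checkpoint is simulated in a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in for the real model_transfer architecture.
    return nn.Linear(4, 2)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_transfer.pt")
    original = build_model()
    torch.save(original.state_dict(), path)  # simulate the saved checkpoint

    model_transfer = build_model()
    # map_location belongs to torch.load, which is where storages get restored.
    state_dict = torch.load(path, map_location="cpu")
    # load_state_dict only receives the already-loaded dict.
    model_transfer.load_state_dict(state_dict)
    model_transfer.eval()
```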