TorchScript model traced on GPU in Python gives CUDA error when loaded in C++

Python

Using PyTorch Lightning, I export the model to TorchScript with torch.jit.trace from a callback during training, so the model is on the GPU when it is traced:

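import torch
import pytorch_lightning as pl

class SimpleModel(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.layer = torch.nn.Linear(10, 3)

    def forward(self, x):  # minimal forward so the module can be traced
        return self.layer(x)

model = SimpleModel()
# in training, inside the callback (model and example_inputs are on the GPU):
traced = torch.jit.trace(func=model.eval(), example_inputs=example_inputs)
torch.jit.save(traced, "test.pt")  # file name as passed to the C++ binary below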
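
To make the device issue concrete, here is a minimal sketch of what seems to happen: a traced module that is saved and then loaded without any remapping comes back on the device its parameters had at save time. TinyModel is a hypothetical plain torch.nn.Module standing in for SimpleModel, and an in-memory buffer replaces the file; both are assumptions made only to keep the example self-contained.

```python
import io
import torch


class TinyModel(torch.nn.Module):  # hypothetical stand-in for SimpleModel
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 3)

    def forward(self, x):
        return self.layer(x)


# Use the GPU when one is present (as during training), otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyModel().to(device).eval()
example_inputs = torch.randn(1, 10, device=device)

traced = torch.jit.trace(model, example_inputs)

# Serialize to an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# Without map_location, tensors are restored on the device they were saved from.
restored = torch.jit.load(buffer)
saved_on = next(traced.parameters()).device
restored_on = next(restored.parameters()).device
print(saved_on, restored_on)
```

On a machine with a GPU both devices should be CUDA devices, which is presumably why a CPU-only libtorch build fails while reading the same archive.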

C++

Now I want to load that trained model in C++ with:

torch::jit::script::Module module;
try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    // torch::Device device(torch::kCPU);
    module = torch::jit::load(argv[1]); //->to(at::kCPU);  // , map_location=torch::device("cpu")
}
catch (const c10::Error& e) {
    std::cerr << "error loading the model\n" << e.what();
    return -1;
}

which gives me the error:

error loading the model
Could not run 'aten::empty_strided' with arguments from the 'CUDA' backend. 'aten::empty_strided' is only available for these backends: [CPU, BackendSelect, Autograd, Profiler, Tracer].
Exception raised from reportError at ../aten/src/ATen/core/dispatch/Dispatcher.cpp:313 (most recent call first):
...

Traceback

$ ./torchscript_simple ../../test.pt
ok
error loading the model
Could not run 'aten::empty_strided' with arguments from the 'CUDA' backend. 'aten::empty_strided' is only available for these backends: [CPU, BackendSelect, Autograd, Profiler, Tracer].
Exception raised from reportError at ../aten/src/ATen/core/dispatch/Dispatcher.cpp:313 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x69 (0x7f516d348eb9 in /project_path/cpp_simple/libtorch/lib/libc10.so)
frame #1: c10::Dispatcher::reportError(c10::DispatchTable const&, c10::DispatchKey) + 0x3dc (0x7f515dc36a5c in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x11a8851 (0x7f515e2cf851 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x11a9e45 (0x7f515e2d0e45 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #4: at::empty_strided(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::TensorOptions const&) + 0x10b (0x7f515e3e5cdb in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x2e4603d (0x7f515ff6d03d in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x11a9e45 (0x7f515e2d0e45 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #7: at::empty_strided(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::TensorOptions const&) + 0x10b (0x7f515e3e5cdb in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #8: at::native::to(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) + 0x49f (0x7f515e00cecf in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x135a2d1 (0x7f515e4812d1 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x2ee0e01 (0x7f5160007e01 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x13a2eca (0x7f515e4c9eca in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #12: at::Tensor::to(c10::Device, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) const + 0x146 (0x7f515e541eb6 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0x3722871 (0x7f5160849871 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x3723940 (0x7f516084a940 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x3723ef1 (0x7f516084aef1 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #16: torch::jit::readArchiveAndTensors(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<std::function<c10::StrongTypePtr (c10::QualifiedName const&)> >, c10::optional<std::function<c10::intrusive_ptr<c10::ivalue::Object, c10::detail::intrusive_target_default_null_type<c10::ivalue::Object> > (c10::StrongTypePtr, c10::IValue)> >, c10::optional<c10::Device>, caffe2::serialize::PyTorchStreamReader&) + 0x6b2 (0x7f51607eb642 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x36c495d (0x7f51607eb95d in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x36c71ff (0x7f51607ee1ff in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #19: torch::jit::load(std::unique_ptr<caffe2::serialize::ReadAdapterInterface, std::default_delete<caffe2::serialize::ReadAdapterInterface> >, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x179 (0x7f51607ee7a9 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #20: torch::jit::load(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x75 (0x7f51607f1355 in /project_path/cpp_simple/libtorch/lib/libtorch_cpu.so)
frame #21: main + 0xd6 (0x557c5ac9c4df in ./torchscript_simple)
frame #22: __libc_start_main + 0xe7 (0x7f515c5b4b97 in /lib/x86_64-linux-gnu/libc.so.6)
frame #23: _start + 0x2a (0x557c5ac9bb5a in ./torchscript_simple)

Loading model onto a specific device?

It seems the C++ API does not have an argument like Python's map_location, as in torch.jit.load(ts_location, map_location=sample.device).

Question

How can I load the model on the CPU in C++ without triggering the CUDA error, and without first moving the model to the CPU in Python before saving it as TorchScript?

Currently I use a workaround: I load the TorchScript file back into Python, map it onto the CPU, and save it again. The resulting file can be used in C++ successfully:

import torch

# load TorchScript model on cpu
torchmodel = torch.jit.load("path/to/model.pt", map_location='cpu')

# save again; the tensors are now stored as CPU tensors
torch.jit.save(torchmodel, "path/to/model_cpu.pt")

Still, I would like to avoid this workaround if possible.
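
For reference, the same remapping can probably be done in one step at export time, without an intermediate file and without moving the training model off the GPU. This is only a sketch of the workaround above, not a way around it: save_traced_for_cpu is a hypothetical helper name, and the torch.nn.Linear model and temporary output path are placeholders.

```python
import io
import os
import tempfile

import torch


def save_traced_for_cpu(traced, path):
    """Write `traced` to `path` with all tensors remapped to the CPU.

    The module passed in stays on its current device; only the
    serialized copy is remapped.
    """
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)  # serialize in memory
    buffer.seek(0)
    cpu_copy = torch.jit.load(buffer, map_location="cpu")
    torch.jit.save(cpu_copy, path)  # this file holds CPU tensors only
    return cpu_copy


# Placeholder model, traced on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
traced = torch.jit.trace(
    torch.nn.Linear(10, 3).to(device).eval(),
    torch.randn(1, 10, device=device),
)

out_path = os.path.join(tempfile.mkdtemp(), "model_cpu.pt")
cpu_copy = save_traced_for_cpu(traced, out_path)
```

The file written to out_path should then load in a CPU-only environment, while traced itself remains on the training device.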

Hi, is there a stable solution to this problem? I mean moving TorchScript models from GPU to CPU and vice versa, without a save-and-load round trip.
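
As far as I can tell, a loaded or traced ScriptModule can be moved between devices with .to(), like a regular torch.nn.Module, so no save-and-load round trip is needed for that part. Below is a minimal sketch with a placeholder torch.nn.Linear model. One caveat to keep in mind: torch.jit.trace can record device values as constants in the graph (for example when forward creates tensors on a specific device), and .to() does not rewrite those, so this is only expected to work cleanly for models like the simple one in the question.

```python
import torch

# Placeholder model, traced on the CPU.
model = torch.nn.Linear(10, 3).eval()
example_inputs = torch.randn(1, 10)
traced = torch.jit.trace(model, example_inputs)
expected = traced(example_inputs)

# Move to the GPU only if one is available, run once, then move back.
if torch.cuda.is_available():
    traced.to("cuda")
    gpu_out = traced(example_inputs.to("cuda"))
traced.to("cpu")

# After the round trip, all parameters should be CPU tensors again.
devices = {p.device.type for p in traced.parameters()}
round_trip_out = traced(example_inputs)
print(devices)
```

In C++, torch::jit::Module has a similar to() method, although that only helps once the module has been loaded successfully.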