Moving model to CUDA in C++

I understand that a model exported with torch.jit always loads onto the CPU by default, but what is the C++ equivalent of .cuda(), please? Sorry for asking something so basic.

5 Likes

try tensor.to(kCUDA)
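
For a plain tensor that would be something like this (a minimal sketch; note the device constant needs qualification, e.g. at::kCUDA or torch::kCUDA):

at::Tensor t = torch::rand({2, 3});   // created on the CPU
at::Tensor t_gpu = t.to(at::kCUDA);   // copied to the default CUDA device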

edit: sorry, didn’t realize you were trying to move a module in c++.

I’m following the tutorial called Loading a PyTorch Model in C++. I was able to trace the Python model, export it to the .pt format, load it in C++, and run inference. That all worked fine on the CPU.

However, I want to do the same thing on the GPU. Following the above suggestion, I made one modification to the .cpp file. I took:
inputs.push_back(torch::ones({1, 3, 224, 224}));
And made it:
inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));

I also added the following include:
#include <ATen/ATen.h>
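
For reference, the relevant part of my example-app.cpp now looks roughly like this (just a sketch; everything else is straight from the tutorial):

std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
assert(module != nullptr);
std::cout << "ok\n";

std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));  // input moved to the GPU

auto output = module->forward(inputs).toTensor();  // this is where it fails, see below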

I then reran make in my build directory. It built. However, here is what happens when I try to run the executable:

$ ./example-app ../model.pt 
ok
terminate called after throwing an instance of 'at::Error'
  what():  Expected object of backend CUDA but got backend CPU for argument #2 'weight' (checked_tensor_unwrap at /pytorch/aten/src/ATen/Utils.h:70)
frame #0: at::CUDAFloatType::thnn_conv2d_forward(at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>) const + 0xb5 (0x7f47506484c5 in /home/abweiss/libtorch/lib/libcaffe2_gpu.so)
frame #1: torch::autograd::VariableType::thnn_conv2d_forward(at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>) const + 0x55f (0x7f477f1729df in /home/abweiss/libtorch/lib/libtorch.so.1)
frame #2: at::TypeDefault::thnn_conv2d(at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>) const + 0x73 (0x7f4774fd1933 in /home/abweiss/libtorch/lib/libcaffe2.so)
frame #3: torch::autograd::VariableType::thnn_conv2d(at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>) const + 0x179 (0x7f477f0d01b9 in /home/abweiss/libtorch/lib/libtorch.so.1)
frame #4: at::native::_convolution_nogroup(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>) + 0x75f (0x7f4774d1490f in /home/abweiss/libtorch/lib/libcaffe2.so)
frame #5: at::TypeDefault::_convolution_nogroup(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>) const + 0x6d (0x7f4774fb680d in /home/abweiss/libtorch/lib/libcaffe2.so)
frame #6: torch::autograd::VariableType::_convolution_nogroup(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>) const + 0x1ae (0x7f477f0c0d2e in /home/abweiss/libtorch/lib/libtorch.so.1)
frame #7: at::native::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>, long, bool, bool, bool) + 0x1b48 (0x7f4774d18b38 in /home/abweiss/libtorch/lib/libcaffe2.so)
frame #8: at::TypeDefault::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>, long, bool, bool, bool) const + 0x93 (0x7f4774fb68e3 in /home/abweiss/libtorch/lib/libcaffe2.so)
frame #9: torch::autograd::VariableType::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::ArrayRef<long>, at::ArrayRef<long>, at::ArrayRef<long>, bool, at::ArrayRef<long>, long, bool, bool, bool) const + 0x22d (0x7f477f0c102d in /home/abweiss/libtorch/lib/libtorch.so.1)
frame #10: <unknown function> + 0x47410b (0x7f477f27610b in /home/abweiss/libtorch/lib/libtorch.so.1)
frame #11: <unknown function> + 0x4a87ed (0x7f477f2aa7ed in /home/abweiss/libtorch/lib/libtorch.so.1)
frame #12: <unknown function> + 0x49486c (0x7f477f29686c in /home/abweiss/libtorch/lib/libtorch.so.1)
frame #13: torch::jit::script::Method::run(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >&) + 0xf6 (0x42ba88 in ./example-app)
frame #14: torch::jit::script::Method::operator()(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >) + 0x4a (0x42bb16 in ./example-app)
frame #15: torch::jit::script::Module::forward(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >) + 0x81 (0x42c325 in ./example-app)
frame #16: main + 0x1fe (0x42787e in ./example-app)
frame #17: __libc_start_main + 0xf0 (0x7f474ee32830 in /lib/x86_64-linux-gnu/libc.so.6)
frame #18: _start + 0x29 (0x426cb9 in ./example-app)

Aborted (core dumped)

Are there additional steps required to make this run on the GPU? It seems that I have moved the input tensor to the GPU, but the model weights are still on the CPU. How do I move the whole model to the GPU?

3 Likes

cc @goldsborough on the C++ API

A couple of follow up comments/questions:

First off, I tried moving my loaded Module to the GPU by doing the following:

  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
  module->to(at::kCUDA);

However, this yields a compile error:

$ make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
/home/abweiss/pt1_example_0/example-app.cpp: In function ‘int main(int, const char**)’:
/home/abweiss/pt1_example_0/example-app.cpp:16:11: error: ‘struct torch::jit::script::Module’ has no member named ‘to’
   module->to(at::kCUDA);

So apparently torch::nn::Module has a to function, but not torch::jit::script::Module.

Second, I have a question. What is the difference between at::kCUDA and torch::kCUDA? When should I be using one instead of the other?

1 Like

TensorOptions.h suggests the following rule of thumb: if there is both an at:: and a torch:: version of the same function, use the at:: version with Tensors and the torch:: version with Variables.
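
For example, my rough understanding of that split in the current preview (an illustration only, not verbatim from TensorOptions.h):

auto v = torch::ones({2, 2}, torch::requires_grad());  // torch:: factory, returns a Variable (autograd-aware)
auto t = at::ones({2, 2});                             // at:: factory, returns a plain ATen Tensor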

1 Like

Does anyone know if it is even possible to move a traced model to the GPU in C++? Maybe this is just not possible in the PyTorch 1.0 preview. It would be really nice if someone could clarify.

5 Likes

I checked the PyTorch code and found that in Python, torch.jit.load returns a torch.jit.ScriptModule, which inherits from torch.nn.Module, so it can be used just like a torch.nn.Module:

traced_script_module = torch.jit.load("model.pt")
traced_script_module.cuda()

but in C++, torch::jit::load returns a std::shared_ptr<torch::jit::script::Module>, where Module is a plain struct: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/script/module.h

2 Likes

Might the answer be that we’re supposed to move certain parts within the module to CUDA rather than the whole thing?

1 Like

I’m still hoping to get a definitive answer on this from one of the developers (@goldsborough? @smth?), or from someone who has already made this work:

In the current PyTorch 1.0 preview, is it possible to move a traced model to the GPU in C++?

If so, how? If not, is there an alternative way to get a C++ GPU model?

These are really fundamental questions for anyone who has been waiting for a way to productionize their PyTorch models.

3 Likes

Dear @richard & @ezyang,
please help us!

I’ve opened an issue at https://github.com/pytorch/pytorch/issues/12686

Seems to me that at https://github.com/pytorch/pytorch/issues/12563 they are talking about the Python interface.

Yeah, the fact that traced modules are loaded back onto the CPU by default was explained at the developer conference. It’s somewhere in https://www.facebook.com/pytorch/videos/482401942168584/ if you’re interested.

How to move them from the CPU to CUDA in Python is well documented. C++ seems to have slipped through the net.

Excellent! My doombots will soon be…

Um, I mean… oh good, there’s a patch. That’s promising.

Hi everyone, sorry for being late to the party. I think there are two broad questions that were asked in this thread:

  1. How can I move my script::Module to CUDA? We found that there indeed was no easy way to do this without iterating over parameters yourself, so I went ahead and implemented script::Module::to(...) in https://github.com/pytorch/pytorch/pull/12710. We’ll try to land it today or tomorrow. (See the sketch after this list for what usage should look like.)

  2. Some of you noticed that the torch::nn::Module class from the C++ frontend has a to(...) method, and you were wondering whether you could mix torch::nn::Module and script::Module. At the moment there is a strict division between torch::nn::Module, which is for the C++ frontend (the pure C++ alternative to the Python eager frontend), and script::Module (the C++ module class for TorchScript), and the two may not be mixed. The torch::nn::Module class is currently friendlier to use because it’s meant to provide the same API as torch.nn.Module in Python, for research. We are actively working on blending the TorchScript C++ API with the C++ frontend API, so I would expect torch::nn::Module and script::Module to become largely the same in the next few months. Feel free to raise more issues for operations on script::Module that are currently hard to do, and we can send more patches.
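
Once that patch lands, usage should look roughly like this (a sketch assuming the script::Module::to(...) overload from #12710, not tested against a released build):

#include <torch/script.h>
#include <iostream>
#include <vector>

int main(int argc, const char* argv[]) {
  // Loads onto the CPU by default.
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

  // Move all parameters and buffers to the GPU (needs the to(...) patch).
  module->to(at::kCUDA);

  // Inputs must live on the same device as the module's weights.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));

  at::Tensor output = module->forward(inputs).toTensor();
  std::cout << output.slice(1, 0, 5) << '\n';
  return 0;
}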

Hope this helps and let me know if you have more questions.

4 Likes

@goldsborough Thank you for the clarifications, and especially for the PR!

You can try loading the model directly onto the GPU, like this:

std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1], torch::kCUDA);

The reply to this post did the trick for me.