LibTorch: convert deeplabv3_resnet101 to C++

I am trying to create an application that runs inference with the deeplabv3 model in C++ using LibTorch. As a first step, I am trying to convert the model for use in C++. Based on the example code, I have:

import torch
import torchvision
from torchvision import models

model = models.segmentation.deeplabv3_resnet101(pretrained=True)
model.eval()


# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

traced_script_module.save("model.pt")

When I run this, I get the error:

Traceback (most recent call last):
  File "convert.py", line 15, in <module>
    traced_script_module = torch.jit.trace(model, example)
  File "C:\Python37\lib\site-packages\torch\jit\__init__.py", line 636, in trace
    var_lookup_fn, _force_outplace)
RuntimeError: Only tensors and (possibly nested) tuples of tensors are supported as inputs or outputs of traced functions (toIValue at C:\a\w\1\s\windows\pytorch\torch/csrc/jit/pybind_utils.h:91)
(no backtrace available)

What am I missing?

Thank you.

torch.jit.trace only supports modules whose inputs and outputs are tensors or (possibly nested) tuples of tensors.
The torchvision deeplabv3 implementation returns an OrderedDict, which is the problem.
To solve it, wrap the model in a module that returns a tuple instead:

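import torch
from torchvision import models

class wrapper(torch.nn.Module):
    def __init__(self, model):
        super(wrapper, self).__init__()
        self.model = model

    def forward(self, input):
        # Flatten the OrderedDict output into a tuple so torch.jit.trace accepts it.
        results = []
        output = self.model(input)
        for v in output.values():
            results.append(v)
        return tuple(results)

deeplab_model = models.segmentation.deeplabv3_resnet101(pretrained=True)
deeplab_model.eval()
model = wrapper(deeplab_model)
traced_script_module = torch.jit.trace(model, torch.rand(1, 3, 224, 224))
traced_script_module.save("model.pt")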

This assumes that every value in the OrderedDict is a tensor.

Thank you! That works, and the model is now saved.

Next, I need to pass a converted image to the model to run inference.
I have the following, but it crashes with a segfault:

	std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("model.pt");
	module->to(torch::kCUDA);
	assert(module != nullptr);
	std::cout << "ok\n";


	cv::Mat image;
	image = cv::imread("pic.jpeg", 1);
	cv::Mat image_resized;
	cv::resize(image, image_resized, cv::Size(224, 224));
	cv::Mat image_resized_float;
	image_resized.convertTo(image_resized_float, CV_32F, 1.0 / 255);

	auto img_tensor = torch::from_blob(image_resized_float.data, { 1, 224, 224, 3 }).to(torch::kCUDA);
	cout << "img tensor loaded..\n";


	img_tensor = img_tensor.permute({ 0, 3, 1, 2 });
	img_tensor[0][0] = img_tensor[0][0].sub(0.485).div(0.229);
	img_tensor[0][1] = img_tensor[0][1].sub(0.456).div(0.224);
	img_tensor[0][2] = img_tensor[0][2].sub(0.406).div(0.225);
	auto img_var = torch::autograd::make_variable(img_tensor, false);

	vector<torch::jit::IValue> inputs;
	inputs.push_back(img_var);
	torch::Tensor out_tensor = module->forward(inputs).toTensor(); //fault
	cout << out_tensor.slice(1, 0, 10) << '\n';

Where am I going wrong?
Thank you again for your time.

…Additionally, when I try with example data, like this:

	std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("model.pt");
	module->to(torch::kCUDA);
	assert(module != nullptr);
	std::cout << "ok\n";


	std::vector<torch::jit::IValue> inputs;
	inputs.emplace_back(torch::rand({ 64, 3, 224, 224 }));

	module->forward(inputs).toTensor();

I get the same crash. Do I need to create a model class in C++, or should this be working? Thank you!

@anti This should be working. Do you have a backtrace for the segfault?

Hi, thanks for your reply. With the ‘torch.rand’ data as above, I get:

Unhandled exception at 0x00007FFE50BD3FB8 in Segmentation.exe: Microsoft C++ exception: std::runtime_error at memory location 0x00000098549B7A80.

The output is:

Exception thrown at 0x00007FFE50BD3FB8 in Segmentation.exe: Microsoft C++ exception: torch::jit::constant_not_supported_error at memory location 0x00000098549BA678.
Exception thrown at 0x00007FFE50BD3FB8 in Segmentation.exe: Microsoft C++ exception: torch::jit::constant_not_supported_error at memory location 0x00000098549BA648.
(the two lines above repeat, alternating, 17 times in total)
Exception thrown at 0x00007FFE50BD3FB8 in Segmentation.exe: Microsoft C++ exception: c10::Error at memory location 0x00000098549BA210.
Exception thrown at 0x00007FFE50BD3FB8 in Segmentation.exe: Microsoft C++ exception: std::runtime_error at memory location 0x00000098549B7A80.
Unhandled exception at 0x00007FFE50BD3FB8 in Segmentation.exe: Microsoft C++ exception: std::runtime_error at memory location 0x00000098549B7A80.

…in the call stack, I have:

torch.dll!torch::jit::InterpreterStateImpl::handleError(std::basic_string<char,std::char_traits<char>,std::allocator<char> > && error_msg, bool is_jit_exception) Line 750 C++