C++ traced model crashes on CPU

Hi.
I have been using tracing to export PyTorch models to C++ libtorch with no trouble so far on several different models. I usually train and export the models on Linux with a GPU, and actually run them on Windows on the CPU. All good.

I however ran into a crash trying to do the same thing with this model:

  • Tracing and running the libtorch model on the GPU is fine
  • On the CPU, the model crashes when calling the forward() method.

I then tried to generate the traced model on the CPU. Using this CPU model I am able to:

  • load and run the traced model in Python: feeding it the expected inputs (i.e. an (image, mask) pair) produces the expected result
  • load and run the model in C++ with dummy inputs (random at::Tensor of the correct size)

However, loading the model in C++ (sketched below) and running it with real inputs (image, mask) crashes when calling forward().
Using this CPU model on the GPU does not crash, but produces erroneous output (a black or uninitialized image).
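
For context, the C++ side is essentially the standard libtorch loading code; here is a minimal sketch of it (the 512x512 input size is a placeholder for my actual inpainting size, not my exact value):

#include <torch/script.h>
#include <vector>

int main() {
    // Load the traced module saved from Python
    torch::jit::script::Module module = torch::jit::load("model/InpaintNet_cpp.pt");

    // Dummy (image, mask) pair with the same shape used for tracing
    at::Tensor image = torch::rand({1, 3, 512, 512});
    at::Tensor mask = torch::rand({1, 3, 512, 512});

    // The traced forward() takes its inputs as a vector of IValues
    std::vector<torch::jit::IValue> inputs{image, mask};
    at::Tensor output = module.forward(inputs).toTensor();
    return 0;
}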

I am not able to isolate where the error might come from, although I suspect something goes wrong in the model tracing. I do not get any error message while tracing.
Any ideas?
I'm using PyTorch 1.3 as well as libtorch 1.3 on both Linux and Windows.

I am loading the network like this:

def load_network(model_path, device):
    if 'cuda' in device:
        torch.backends.cudnn.benchmark = True
    else:
        torch.backends.cudnn.benchmark = False

    inpainting_model = PConvUNet().to(device)
    inpainting_model.eval()
    load_ckpt(model_path, [('model', inpainting_model)], device=device)
    return inpainting_model

def load_ckpt(ckpt_name, models, optimizers=None, device=None):
    ckpt_dict = torch.load(ckpt_name, map_location=device)
    for prefix, model in models:
        assert isinstance(model, nn.Module)
        model.load_state_dict(ckpt_dict[prefix], strict=False)
    if optimizers is not None:
        for prefix, optimizer in optimizers:
            optimizer.load_state_dict(ckpt_dict[prefix])
    return ckpt_dict['n_iter']

and then just calling jit.trace:

net = load_network(net_path, device)
#dummy input for tracing the model
image = torch.rand(1, 3, inpainting_size[0], inpainting_size[1])
mask = torch.rand(1, 3, inpainting_size[0], inpainting_size[1])
                
image = image.to(device)
mask = mask.to(device)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(net, (image, mask))
traced_script_module.save("model/InpaintNet_cpp.pt")

Could you post the error message, please?
Also, what’s the difference between the random input tensors and the correct entries?
While the first approach seems to work, the latter crashes, correct?

Yes, all these different configurations make the problem a bit confusing. I am testing on two different machines: Linux with a GPU, and Windows with no GPU. Here is the state of it right now in C++:

// crash results on CPU/Windows:
DUMMY_INPUT + CPU_MODEL -> no crash
REAL_INPUT  + CPU_MODEL -> crash
DUMMY_INPUT + GPU_MODEL -> crash
REAL_INPUT  + GPU_MODEL -> crash

// crash results on GPU/Linux:
DUMMY_INPUT + CPU_MODEL -> no crash
REAL_INPUT  + CPU_MODEL -> no crash, wrong result
DUMMY_INPUT + GPU_MODEL -> no crash
REAL_INPUT  + GPU_MODEL -> ok, good result

What I call "CPU_MODEL" is the model traced on the CPU.

I do not have a specific error message other than segmentation fault (core dumped).
Debugging with Visual C++ raises an exception in kernel_lambda.h, in this function of the WrapRuntimeKernelFunctor_ class:

auto operator()(Parameters... args) -> decltype(std::declval<FuncType>()(std::forward<Parameters>(args)...)) {
    return kernel_func_(std::forward<Parameters>(args)...);
}
The thing is that if I load the CPU model on my Windows machine using Python with

torch.jit.load

then the output is correct, which seems to indicate that the model tracing was done correctly. A libtorch problem, then?

Hard to tell where this error is coming from, but please feel free to create an issue on GitHub.
If possible, attach a code snippet to reproduce this error.

For the record, I managed to debug this after a while.
There were two issues:
The first issue was that I was creating a tensor from an OpenCV matrix using torch::from_blob(). This function does not copy the data, so when the OpenCV matrix goes out of scope, the tensor points to garbage. This does not happen on the GPU because the data is copied into GPU memory before the OpenCV matrix goes out of scope.
Cloning the tensor right after torch::from_blob() solved the crash on CPU with the CPU-traced model (see the sketch below).
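
For anyone hitting the same thing, here is a minimal sketch of the fix. The helper name, the input layout and the normalization are just illustrative assumptions about my preprocessing; the relevant part is the .clone() after torch::from_blob():

#include <opencv2/opencv.hpp>
#include <torch/script.h>

// Convert an OpenCV image to a torch tensor that owns its own memory.
at::Tensor mat_to_tensor(const cv::Mat& img_bgr) {
    cv::Mat img_float;  // local: destroyed when this function returns
    img_bgr.convertTo(img_float, CV_32FC3, 1.0 / 255.0);

    // from_blob() only wraps the OpenCV buffer, it does not copy the data
    at::Tensor t = torch::from_blob(
        img_float.data, {1, img_float.rows, img_float.cols, 3}, torch::kFloat32);

    // clone() copies the data into tensor-owned memory, so the result stays
    // valid after img_float goes out of scope; then reorder NHWC -> NCHW
    return t.clone().permute({0, 3, 1, 2}).contiguous();
}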

The second issue is that in the Python model, some tensors are created on the fly in the forward() method. They are created as CUDA tensors and seem to be traced with the CUDA device hard-coded (?). When running on CPU, this results in:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I am happy to use a separate CPU model for now, but I guess this is solvable by moving these tensors to the right device before tracing (e.g. creating them with device=input.device instead of hard-coding .cuda()).
