Unrecognized data format using libtorch when loading TorchScript exported model (Python) in C++ context

Hi :slight_smile:

Setup: Win10, Visual Studio 2017, libtorch 1.12.1 (release, tried CPU and CUDA versions), CMake 3.12 (integrated in VS2017)

Goal: Export model from PyTorch (Python) to file using TorchScript and load it using libtorch (C++).

Issue:

When loading any exported module I am getting an error (see below). The actual file here is not important (although I have provided the Python code to show how I am doing it) since the error occurs with any trained model I tried (PT/PTH files downloaded from the internet). The issue is therefore probably not in the way I am saving the file but rather either the serialization or de-serialization of it.

Unrecognized data format
Exception raised from load at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\jit\serialization\import.cpp:449 (most recent call first):
00007FFEBEDADA2200007FFEBEDAD9C0 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFEBEDAD43E00007FFEBEDAD3F0 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00007FFE4FF2B54700007FFE4FF2B4E0 torch_cpu.dll!torch::jit::load [<unknown file> @ <unknown line number>]
00007FFE4FF2B42A00007FFE4FF2B380 torch_cpu.dll!torch::jit::load [<unknown file> @ <unknown line number>]
00007FF64CF2682B00007FF64CF266E0 pytroch_load_model.exe!main [c:\users\USER\projects\cmake dx cuda pytorch\cmake_integration_examples\pytorch\src\pytroch_load_model.cpp @ 17]
00007FF64CF51C2400007FF64CF51BF0 pytroch_load_model.exe!invoke_main [d:\agent\_work\2\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl @ 79]
00007FF64CF51ACE00007FF64CF519A0 pytroch_load_model.exe!__scrt_common_main_seh [d:\agent\_work\2\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl @ 288]
00007FF64CF5198E00007FF64CF51980 pytroch_load_model.exe!__scrt_common_main [d:\agent\_work\2\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl @ 331]
00007FF64CF51CB900007FF64CF51CB0 pytroch_load_model.exe!mainCRTStartup [d:\agent\_work\2\s\src\vctools\crt\vcstartup\src\startup\exe_main.cpp @ 17]
00007FFEDD72703400007FFEDD727020 KERNEL32.DLL!BaseThreadInitThunk [<unknown file> @ <unknown line number>]
00007FFEDDA0265100007FFEDDA02630 ntdll.dll!RtlUserThreadStart [<unknown file> @ <unknown line number>]

Steps to reproduce:

I am using the following code:

Python

Currently I just create some modules and try to export those. Training/Evaluation as well as the actual purpose of the modules is of no importance. All of it is directly taken either from the official documentation on TorchScript or the PyTorch forums.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small convolutional classifier used as an export test case.

    Two conv + max-pool stages followed by three fully connected layers.
    The flatten step in ``forward`` (16 * 5 * 5 features) implies 3-channel
    32x32 inputs; the network produces 10 output values per sample.
    """

    def __init__(self):
        # The dunder name is required: a method called plain ``init`` is never
        # invoked by ``Net()``, which would leave the layers below undefined.
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run a batch ``x`` through the network and return the raw outputs."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the feature maps to (batch, 16 * 5 * 5) for the linear layers
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class TestModel(nn.Module):
    """Minimal module for export tests: ``forward`` ignores its input and returns ``self.x``."""

    def __init__(self):
        # Must be ``__init__`` (not ``init``) so that it actually runs on construction
        super(TestModel, self).__init__()
        self.x = 2

    def forward(self, x):
        """Return the stored constant; the argument ``x`` is unused."""
        return self.x

class TensorContainer(nn.Module):
    """Module that exposes every entry of ``tensor_dict`` as an attribute.

    Each ``key``/``value`` pair is attached with ``setattr`` so the values can
    be exported together with the module.
    """

    def __init__(self, tensor_dict):
        # Must be ``__init__`` (not ``init``) so the attributes are really set
        super().__init__()
        for key, value in tensor_dict.items():
            setattr(self, key, value)

# Script and save the convolutional network
test_net = Net()
print(test_net)
test_net = torch.jit.script(test_net)
torch.jit.save(test_net, 'test_net.pt')

# Script and save the constant-returning module
test_module = TestModel()
print(test_module)
test_module = torch.jit.script(test_module)
torch.jit.save(test_module, 'test_module.pt')

# Wrap a plain tensor in a module so it can be saved through TorchScript
prior = torch.ones(3, 3)
tensor_dict = {'prior': prior}
print(tensor_dict)
tensors = TensorContainer(tensor_dict)
# ``__dict__`` (with the dunders) is the instance attribute dictionary
print(tensors.__dict__)
tensors = torch.jit.script(tensors)
tensors.save('values.pt')

C++

The loaded model (C++ context) is currently not used at all. The main point here is to just load the thing from a file. I tried both relative path (e.g. the EXE and the exported modules are in the same directory and I can just do binary.exe module.pt) and absolute path (just in case that libtorch doesn’t deal properly with relative paths; path does not contain any Unicode).

#include <torch/script.h>
#include <iostream>
#include <memory>

// Command-line entry point: tries to deserialize the TorchScript module whose
// path is given as the single argument and reports whether that worked.
// Returns -1 on wrong usage or when loading throws c10::Error, 0 otherwise.
int main(int argc, const char* argv[]) {
	// Exactly one argument (the module path) is expected after the program name
	const bool has_model_path = (argc == 2);
	if (!has_model_path) {
		std::cerr << "Usage: pytroch_load_model <path-to-exported-script-module>\n";
		return -1;
	}

	const char* const model_path = argv[1];
	torch::jit::script::Module module;

	std::cout << "Trying to load model..." << std::endl;
	try {
		// Deserialize the module from the file
		module = torch::jit::load(model_path);
	}
	catch (const c10::Error& e) {
		// Print the full libtorch error text
		std::cerr << "Loading failed" << std::endl;
		std::cerr << e.what() << std::endl;
		return -1;
	}

	std::cout << "Loading successful" << std::endl;
	return 0;
}

CMakeLists.txt

The project is a subdirectory (sub-project). The Torch libraries, headers etc. are located in the top-level directory of the main project (hence the ${CMAKE_SOURCE_DIR} call). All required libraries are linked properly (no errors during link stage) and DLLs are accordingly copied to the directory where the final EXE is located.

cmake_minimum_required (VERSION 3.12 FATAL_ERROR)

project(pytroch
  DESCRIPTION "CMake example for PyTorch (libtorch C++) integration"
  LANGUAGES CXX
)

set(CMAKE_CXX_STANDARD 14)

set(SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src")

# Single place that selects the libtorch distribution (version / build type /
# device flavour). Both the package lookup and the DLL copy below use it, so
# the two can never point at different builds.
set(LIBTORCH_ROOT "${CMAKE_SOURCE_DIR}/deps/libtorch/1.12.1/release/cpu")

# Append instead of overwrite so prefixes supplied by the parent project or on
# the command line are preserved.
list(APPEND CMAKE_PREFIX_PATH "${LIBTORCH_ROOT}/share/cmake/Torch")
find_package(Torch REQUIRED)
if(TORCH_FOUND)
	message(STATUS "Found Torch")
else()
	# FATAL_ERROR is the message() mode that stops configuration;
	# "CRITICAL_ERROR" is not a mode and would be printed as ordinary text.
	message(FATAL_ERROR "Unable to find Torch")
endif()

add_executable(pytroch_load_model
	"${SRC_DIR}/pytroch_load_model.cpp"
)
target_include_directories(pytroch_load_model PUBLIC ${TORCH_INCLUDE_DIRS})
target_link_libraries(pytroch_load_model PRIVATE ${TORCH_LIBRARIES})

# Copy the libtorch runtime DLLs into the binary directory.
# Note: file(COPY) runs at configure time, not at build time.
file(GLOB LIBTORCH_DLLS
  "${LIBTORCH_ROOT}/lib/*.dll"
)
file(COPY
	${LIBTORCH_DLLS}
	DESTINATION "${CMAKE_BINARY_DIR}/bin/"
)
1 Like

Hi there @rbaleksandar, did you happen to find a solution of the problem? I seem to have a similar issue.

I forgot to post the solution. Yes, I was able to load my model successfully. I created a new model (in my case SRCNN) that I experimented with, but this should apply to any generic model out there. I am not an expert in libtorch, Torch RT, or PyTorch, and I do not claim that there will be no issues. Just a small disclaimer. :wink:

In a separate script (to keep things clean between the development of the model and its conversion) I do the following:

  1. Instantiate the model. Here it is important to always load the trained model on the CPU:
net = SRCNN()
# map_location remaps tensors that were saved from a GPU onto the CPU while
# loading, so the checkpoint also loads on machines without CUDA
net.load_state_dict(torch.load(path_to_model, map_location=torch.device('cpu')))
net.to(torch.device('cpu'))
# Put the model into evaluation mode before it is traced
net.eval()
  2. Load an example input. Note that the input needs to match the data you have trained your model with. I was stupid enough to train my SRCNN with grayscale images, so I cannot use it for RGB: the number of channels differs, and with it the shape of the tensor.
# Open the sample image and convert it to PIL mode 'L' (single-channel grayscale)
example_inputs = Image.open('data/butterfly_GT.bmp').convert('L')
# Turn the image into a tensor and prepend a batch dimension
example_inputs = transforms.ToTensor()(example_inputs).unsqueeze(0)
  3. Trace the model with the given sample:
# Despite the variable name, this module is produced by tracing (torch.jit.trace), not by torch.jit.script
net_scripted = torch.jit.trace(net, example_inputs)
  4. Save it:
# Serialize the traced module to the path held in `output`
torch.jit.save(net_scripted, output)

As for the C++ part here is a small sample code. It uses torch::from_blob() to create a tensor from an OpenCV image (cv::Mat). torch::from_blob() can be used for any array out there as long as you give the correct options.

// Load the image for the SRCNN as grayscale since we have trained the SRCNN with single channel 8bit images
cv::Mat img = cv::imread(path_to_image, cv::IMREAD_GRAYSCALE);
// Convert to float (required by our model)
img.convertTo(img, CV_32FC3, 1.0f / 255.0f);
// Scale up to introduce larger dimensions but also noise due to the resizing. This will be the input for the SRCNN
cv::Mat img_resized;
auto scale_up = cv::Size(img.size().height * 2, img.size().width * 2);
cv::resize(img, img, scale_up);

// Upload the Torch model to the GPU
torch::jit::script::Module module;
try {
	// Deserialize the ScriptModule from a file using torch::jit::load().
	module = torch::jit::load(path_model, c10::kCUDA);	//c10::kCPU
}
catch (const c10::Error& e) {
	std::cout << "Failed" << std::endl;
	std::cerr << e.what() << std::endl;
	return -1;
}

// Adapt the tensor further so that it fits the expected SRCNN input
// Here we need to do a permutation of the channels
auto tensor_input = torch::from_blob(img.data, { 1, img.size().height, img.size().width, 1 }, torch::kFloat32);
tensor_input = tensor_input.permute({ 0, 3, 1, 2 });

// Load the input to the GPU and run it through the model
torch::Tensor tensor_output;
try
{
	tensor_input = tensor_input.to(at::kCUDA);
	tensor_output = module.forward({ tensor_input }).toTensor();
}
catch (std::runtime_error& err)
{
	std::cout << "Failed" << std::endl;
	std::cerr << err.what() << std::endl;
}
catch (std::exception& ex)
{
	std::cout << "Failed" << std::endl;
	std::cerr << ex.what() << std::endl;
}

// Post-process the tensor incl. copying to CPU, reversing the changes applied to the input to fulfill the model's requirements and so on
tensor_output = tensor_output.squeeze(0).detach().permute({ 1, 2, 0 });
tensor_output = tensor_output.mul(255).clamp(0, 255).to(torch::kU8);
tensor_output = tensor_output.to(torch::kCPU);

// Convert the tensor back to an image
cv::Mat img_out(img.size().height, img.size().width, CV_8UC1);
std::memcpy((void*)img_out.data, tensor_output.data_ptr(), sizeof(torch::kU8) * tensor_output.numel());

// Use the image
// ...

Fun fact: you can even map a CUDA array to a Tensor (again with torch::from_blob()) and, in case you are using OpenCV’s cv::cuda::GpuMat, do a lot of the pre-processing, processing and post-processing of your data completely on the GPU with very little copying back and forth. If copying is something you need to do, you can use pinned (page-locked) host memory, which is accessible to both CPU and GPU and will not be swapped out by the OS, thus making transfers between the two faster.

Here is an example of libtorch and CUDA interoperability:

#include <cuda_runtime.h>
#include <torch/torch.h>
#include <iostream>
#include <exception>
#include <memory>
#include <math.h>

using std::cout;
using std::endl;
using std::exception;

/*
 * Demonstration of interoperability between CUDA and Torch C++ API
 * 
 * Using pinned memory (for fast transfer between CPU and GPU) three
 * vectors are created, two of which first on the CPU followed by copy
 * onto the GPU (inputs) and one directly onto the GPU (result).
 * 
 * The two inputs are then used in a simple calculation and stored into
 * the result. All is done in CUDA. The result is then converted to 
 * Torch tensors directly onto the GPU in order to demonstrate inter-
 * operability between the two APIs.
 * 
 * Using the ENABLE_ERROR variable a change in the result (CUDA) can be
 * introduced through its respective Torch tensor. This will also affect
 * the copied data from GPU to CPU, resulting in an error during assert
 * checks at the end.
 * The purpose of this experiment was to verify that the created Torch
 * tensors are actually altering the data on the GPU and not on the CPU.
 */


// Careful with libtorch library version - release vs debug, cpu vs cuda

// Contains the call to the CUDA kernel
// Host-side launcher: main() passes the three device buffers (a, b: inputs,
// c: result), the element count N and the launch configuration.
void vector_add(int* a, int* b, int* c, int N, int cuda_grid_size, int cuda_block_size);

// When true, main() overwrites the first element of the result through its
// Torch tensor instead of checking the result on the GPU.
bool ENABLE_ERROR = false;

int main(int argc, const char* argv[])
{
	// Setup array, here 2^16 = 65536 items
	const int N = 1 << 16;
	size_t bytes = N * sizeof(int);

	// Declare pinned memory pointers
	int* a_cpu, * b_cpu, * c_cpu;

	// Allocate pinned memory for the pointers
	// The memory will be accessible from both CPU and GPU
	// without the requirements to copy data from one device
	// to the other
	cout << "Allocating memory for vectors on CPU" << endl;
	cudaMallocHost(&a_cpu, bytes);
	cudaMallocHost(&b_cpu, bytes);
	cudaMallocHost(&c_cpu, bytes);

	// Init vectors
	cout << "Populating vectors with random integers" << endl;
	for (int i = 0; i < N; ++i)
	{
		a_cpu[i] = rand() % 100;
		b_cpu[i] = rand() % 100;
	}

	// Declare GPU memory pointers
	int* a_gpu, * b_gpu, * c_gpu;

	// Allocate memory on the device
	cout << "Allocating memory for vectors on GPU" << endl;
	cudaMalloc(&a_gpu, bytes);
	cudaMalloc(&b_gpu, bytes);
	cudaMalloc(&c_gpu, bytes);

	// Copy data from the host to the device (CPU -> GPU)
	cout << "Transfering vectors from CPU to GPU" << endl;
	cudaMemcpy(a_gpu, a_cpu, bytes, cudaMemcpyHostToDevice);
	cudaMemcpy(b_gpu, b_cpu, bytes, cudaMemcpyHostToDevice);

	// Specify threads per CUDA block (CTA), her 2^10 = 1024 threads
	int NUM_THREADS = 1 << 10;

	// CTAs per grid
	int NUM_BLOCKS = (N + NUM_THREADS - 1) / NUM_THREADS;

	// Call CUDA kernel
	cout << "Running CUDA kernels" << endl;
	vector_add(a_gpu, b_gpu, c_gpu, N, NUM_BLOCKS, NUM_THREADS);

	try
	{
		// Convert pinned memory on GPU to Torch tensor on GPU
		auto options = torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA, 0).pinned_memory(true);
		cout << "Converting vectors and result to Torch tensors on GPU" << endl;
		torch::Tensor a_gpu_tensor = torch::from_blob(a_gpu, { N }, options);
		torch::Tensor b_gpu_tensor = torch::from_blob(b_gpu, { N }, options);
		torch::Tensor c_gpu_tensor = torch::from_blob(c_gpu, { N }, options);

		cout << "Verifying result using Torch tensors" << endl;
		if (ENABLE_ERROR)
		{
			/*
			TEST
			Change the value of the result should result in two things:
			 - the GPU memory will be modified
			 - the CPU test later on (after the GPU memory is copied to the CPU side) should fail
			*/
			cout << "ERROR GENERATION ENABLED! Application will crash during verification of results" << endl;
			cout << "Changing result first element from " << c_gpu_tensor[0];
			c_gpu_tensor[0] = 99999999;
			cout << " to " << c_gpu_tensor[0] << endl;
		}
		else
		{
                        // Check the result on the GPU by using the provided Torch functionaly
			assert(c_gpu_tensor.equal(a_gpu_tensor.add(b_gpu_tensor)) == true);
		}
	}
	catch (exception& e)
	{
		cout << e.what() << endl;

		cudaFreeHost(a_cpu);
		cudaFreeHost(b_cpu);
		cudaFreeHost(c_cpu);

		cudaFree(a_gpu);
		cudaFree(b_gpu);
		cudaFree(c_gpu);

		return 1;
	}

	// Copy memory to device and also synchronize (implicitly)
	cout << "Synchronizing CPU and GPU. Copying result from GPU to CPU" << endl;
	cudaMemcpy(c_cpu, c_gpu, bytes, cudaMemcpyDeviceToHost);

	// Verify the result on the CPU
	cout << "Verifying result on CPU" << endl;
	for (int i = 0; i < N; ++i)
	{
		assert(c_cpu[i] == a_cpu[i] + b_cpu[i]);
	}

	cudaFreeHost(a_cpu);
	cudaFreeHost(b_cpu);
	cudaFreeHost(c_cpu);

	cudaFree(a_gpu);
	cudaFree(b_gpu);
	cudaFree(c_gpu);

	/*
	double array[] = { 1, 2, 3, 4, 5 };
	auto options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCUDA, 1);
	torch::Tensor tharray = torch::from_blob(array, { 5 }, options);
	*/

	return 0;
}

with a simple CUDA kernel

// Element-wise addition kernel: each thread whose global index is below N
// writes c[i] = a[i] + b[i] for its own index i.
__global__ void vector_add_kernel(int* a, int* b, int* c, int N)
{
	// Global index of this thread along x
	const int idx = threadIdx.x + (blockIdx.x * blockDim.x);

	// Boundary check: indices at or past N have nothing to do
	if (idx >= N)
	{
		return;
	}

	c[idx] = a[idx] + b[idx];
}

// Host-side wrapper that launches vector_add_kernel on the buffers a, b and c
// (N elements each) with cuda_grid_size blocks of cuda_block_size threads.
void vector_add(int* a, int* b, int* c, int N, int cuda_grid_size, int cuda_block_size)
{
	vector_add_kernel << <cuda_grid_size, cuda_block_size >> > (a, b, c, N);
	// NOTE(review): the status returned by cudaGetLastError() is discarded, so a
	// failed launch is not reported to the caller; the call only clears the
	// pending error state.
	cudaGetLastError();
}

The last thing is in regards to CMake or whatever tool you are using to manage and build your project. It is absolutely critical that you pick the right version of libtorch. Is it debug or is it release build? Is it for CPU only or CPU and GPU? For a setup that would require both running and debugging your project, that means you need to obtain a couple of gigabytes since the CUDA enabled version of libtorch is humongous.

1 Like

See my solution and let me know if it works for you. Perhaps I can help you out.