LibTorch CPU inference is much slower on Windows than on Linux

I found that PyTorch / LibTorch 1.10 is significantly slower during CPU inference on Windows 10 than on Linux for certain topologies (classifiers with fully-connected/dense layers).

The model used in the following problem description is created and trained in a PyTorch environment (Python 3.9) and exported via TorchScript tracing (torch.jit.trace). After the trained model has been exported, it is loaded and executed from a C++ application using LibTorch.

The average runtime over one hundred forward passes on a dummy image is listed below for each platform. Multithreading was disabled for the measurement.

Windows: 218 ms
Linux: 40 ms

The CPUs are not identical, but the gap in runtime cannot be explained by the hardware alone; if anything, the CPU in the Windows machine should perform better on single-core tasks.

To reproduce the results, the following code snippets are required.
Note: the input dimensions in the examples below are 256 × 256 × 1 (HWC).

Python Script for the model:


import torch.nn as nn
import torch
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2,2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24*122*122, num_classes)

    def forward(self, input):
        output = F.relu(self.bn1(self.conv1(input)))   # [N, 12, 254, 254]
        output = F.relu(self.bn2(self.conv2(output)))  # [N, 12, 252, 252]
        output = self.pool(output)                     # [N, 12, 126, 126]
        output = F.relu(self.bn4(self.conv4(output)))  # [N, 24, 124, 124]
        output = F.relu(self.bn5(self.conv5(output)))  # [N, 24, 122, 122]
        output = output.view(-1, 24*122*122)           # flatten for the fully-connected layer
        output = self.fc1(output)
        return output

Python code to export the model (images is an example input batch of shape [N, 1, 256, 256]):

traced_script_module = torch.jit.trace(model, images)
traced_script_module.save(params["model_path"] + f'CP_epoch{epoch + 1}.pt')

C++ program to measure the runtime:

void test()
{
	at::set_num_threads(1);  // single-threaded measurement
	at::init_num_threads();  // initialize ATen's thread state for this thread
	torch::jit::script::Module module = torch::jit::load("classifier.pt", c10::DeviceType::CPU);
	module.eval();
	// Build a dummy 256x256 single-channel image and convert it from HWC uint8 to NCHW float.
	cv::Mat m = cv::Mat::ones(256, 256, CV_8UC1);
	torch::Tensor tensor_image = torch::from_blob(m.data, { m.rows, m.cols, m.channels() }, at::kByte);
	tensor_image = tensor_image.permute({ 2, 0, 1 });
	tensor_image = tensor_image.toType(torch::kFloat32);
	tensor_image = tensor_image.unsqueeze(0);              // add the batch dimension -> [1, 1, 256, 256]
	tensor_image = tensor_image.to(c10::DeviceType::CPU);  // to() is not in-place, so assign the result
	torch::Tensor output;
	auto start = std::chrono::high_resolution_clock::now();
	const int runs = 100;
	for (int i = 0; i < runs; i++)
	{
		output = module.forward({ tensor_image }).toTensor().detach();
	}
	auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - start).count();
	std::cout << duration / (float) runs << " ms" << std::endl;
}

In both cases, PyTorch is linked against Intel’s MKL (oneAPI 2021.3.0), and MKLDNN is enabled on both platforms. The compiler on Windows is MSVC 14.29.30133; on Linux it is gcc (SUSE Linux) 7.5.0. By default, MKL is linked statically on Windows and dynamically on Linux.
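
To verify what each build is actually using at runtime, the configuration and thread settings can be printed from the C++ side. A minimal sketch using ATen’s reporting helpers (the wrapper function name is illustrative):

#include <iostream>
#include <ATen/Parallel.h>
#include <ATen/Version.h>
#include <torch/torch.h>

// Print the build configuration (compiler, BLAS/MKL, MKLDNN, flags) and the
// intra-/inter-op thread settings that are in effect for this process.
void print_build_info()
{
	std::cout << at::show_config() << std::endl;
	std::cout << "intra-op threads: " << at::get_num_threads() << std::endl;
	std::cout << "inter-op threads: " << at::get_num_interop_threads() << std::endl;
}

Comparing this output on both machines helps rule out configuration differences (e.g. a different BLAS backend or MKLDNN being unavailable) before attributing the gap to the operating system.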

The following instructions were used for the build of PyTorch:

Build Description

[Edit]
I’m on Windows (VS2019, LibTorch 1.13.0) and I have a similar issue: inference is very slow, but only on the first pass, not on subsequent passes. Inference within Python/PyTorch itself does not suffer from this.

My solution, which you could try, is this: after loading my model I run a “dummy” input through it, which sets up the initial CUDA state (I’m sure a legend like @ptrblck knows what is actually going on under the hood). Then, when I do my real inference at runtime, the cost is very cheap:

INFO::2022-12-10::10:42:59::Dummy()::Duration: 3798ms
INFO::2022-12-10::10:42:59::Forward_GPU()::Duration: 4ms

I can’t post my actual code at the moment because I haven’t graduated yet (this is dissertation code). If I remember, I will come back and add it later.

Note: the idea of using a “dummy” input to eliminate this initial inference cost came from https://github.com/pytorch/pytorch/issues/44269
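
A minimal sketch of that warmup idea (not the code from this post; it assumes a CUDA build of LibTorch, a module that has already been moved to the GPU, and an example input shape):

#include <torch/script.h>
#include <torch/cuda.h>

// Push one throwaway input through the model so that CUDA context creation,
// initial memory allocations and JIT optimization passes happen up front.
void warmup(torch::jit::script::Module& module)
{
	torch::NoGradGuard no_grad;
	// Dummy input with the same shape/dtype as the real data (example: NCHW float).
	torch::Tensor dummy = torch::zeros({ 1, 1, 256, 256 }, torch::kFloat32).to(torch::kCUDA);
	module.forward({ dummy });
	torch::cuda::synchronize();  // make sure the warmup work has actually finished
}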

Yes, your explanation is correct, and warmup iterations are a good idea if your workload needs to avoid these slow first passes.
The first iteration(s) are slow for multiple reasons:

  • The very first CUDA call (it could be a tensor creation etc.) is creating the CUDA context, which loads the driver etc. In older CUDA versions (<11.7) all kernels for your GPU architecture were also directly loaded into the context, which takes time and uses memory. Since CUDA 11.7 we’ve enabled “lazy module loading”, which will only load the called kernel into the context if needed. This will reduce the startup time as well as the memory usage significantly and is enabled by default in our pip wheels and conda binaries using CUDA 11.7.
  • The first iterations of your actual workload need to allocate new memory, which will then be reused through the CUDACachingAllocator. However, the initial cudaMalloc calls are also “expensive” (compared to just reusing the already allocated memory), and you would thus also see a slow iteration time until your workload reaches its peak memory usage and is able to reuse the GPU memory. Note that new cudaMalloc calls could of course still happen during training, e.g. if your input size increases.
  • Depending on your use case and if torch.jit.script or torch.compile (in the current nightly releases and soon in PyTorch 2.0) is used, you might also see JIT compiled kernels, optimization passes etc. I believe the optimization needs 3 iterations in torch.jit.script, but don’t know how torch.compile would behave.
  • If you are using conv layers and are allowing cuDNN to benchmark valid kernels and select the fastest one (via torch.backends.cudnn.benchmark = True) the profiling and kernel selection for each new workload (i.e. new input shape, new dtype etc. to the conv layer) will also see an overhead.

As you can see, warmup iterations could help a lot as the first steps are expected to create some overhead.
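
Putting those points together, a GPU timing loop with warmup in LibTorch could look roughly like the sketch below (the function name, warmup count and the cuDNN benchmarking call are illustrative, not from the posts above):

#include <chrono>
#include <iostream>
#include <torch/script.h>
#include <torch/cuda.h>

// Time a scripted module on the GPU: enable cuDNN autotuning, run a few warmup
// iterations (CUDA context, allocator, JIT optimization passes), then time the
// steady state with explicit synchronization around the measured region.
void time_with_warmup(torch::jit::script::Module& module, const torch::Tensor& input)
{
	torch::NoGradGuard no_grad;
	at::globalContext().setBenchmarkCuDNN(true);  // C++ counterpart of torch.backends.cudnn.benchmark = True

	for (int i = 0; i < 5; i++)  // warmup iterations absorb the one-time costs listed above
		module.forward({ input });
	torch::cuda::synchronize();

	const int runs = 100;
	auto start = std::chrono::high_resolution_clock::now();
	for (int i = 0; i < runs; i++)
		module.forward({ input });
	torch::cuda::synchronize();  // wait for all timed kernels to finish before reading the clock
	auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
		std::chrono::high_resolution_clock::now() - start).count();
	std::cout << duration / (float) runs << " ms per iteration" << std::endl;
}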

Thank you, I appreciate the detailed response.