Slow inference with libtorch on CPUs

Hello,

I am training a model with PyTorch on an A100 GPU and then saving it with TorchScript as a scripted model with

m = torch.jit.script(model)
torch.jit.save(m, model_filename)

I then convert it to a CPU model with

model = torch.load(model_filename)
model.eval()
torch.jit.save(torch.jit.script(model), cpu_model_filename)

and then load it and run inference from my C++ code. However, inference is really slow — slower than running the same model from Python on the CPU. I’m using version 2.5.1 for both PyTorch and libtorch.

Any insights would be greatly appreciated!

Just for completeness, I’m including my model code and the code that calls model.forward() from C++:

import torch
import torch.nn.functional as F
from torch import nn
from typing import Dict


class LayerNorm2D(nn.LayerNorm):
    """Layer normalization across the channel axis of a 4-D NCHW tensor.

    The input is viewed channels-last so ``F.layer_norm`` can normalize the
    trailing ``num_channels`` axis at every (batch, y, x) position. The result
    is then permuted back to NCHW order.

    Args:
        num_channels: Size of the channel axis (axis 1) of the input.
        eps: Value added to the variance for numerical stability.
        affine: Whether to learn a per-channel scale and shift.
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x):
        """Normalize ``x`` over its channel axis and return it in NCHW order."""
        channels_last = x.permute(0, 2, 3, 1)
        normalized = F.layer_norm(
            channels_last, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normalized.permute(0, 3, 1, 2)


class ConvBlock(nn.Module):
    """Reflect-padded 2-D convolution followed by an activation.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution.
        kernel_size: Integer kernel size. ``kernel_size // 2`` cells of
            reflect padding are added on each side, so odd sizes preserve
            height and width.
        activation: Activation *class* (e.g. ``nn.GELU`` or ``nn.Identity``).
            It is instantiated once here, with no arguments.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        activation,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.activation = activation()
        # Integer floor division: same value as int(kernel_size / 2) for
        # non-negative ints, without the float round trip.
        self.padding = kernel_size // 2
        self.conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            padding=self.padding,
            padding_mode='reflect',
        )

    def forward(self, x):
        """Apply the convolution, then the activation."""
        return self.activation(self.conv(x))


class ConvNeXTBlock(nn.Module):
    """Residual bottleneck block.

    ``forward`` computes ``x + conv3(act(conv2(norm(conv1(x)))))``, where:

    * ``conv1`` is a dense (``groups`` is not set) reflect-padded
      ``kernel_size`` x ``kernel_size`` convolution from ``in_channels`` to
      ``mid_channels``;
    * ``norm`` is a channel-wise ``LayerNorm2D``;
    * ``conv2`` (mid -> mid) and ``conv3`` (mid -> out) are 1x1 convolutions.

    Args:
        in_channels: Channels of the input tensor.
        mid_channels: Inner (bottleneck) width.
        out_channels: Channels of the output. Must equal ``in_channels``
            because the input is added back onto the output.
        kernel_size: Integer kernel size of ``conv1``. Odd sizes preserve
            height and width.
        activation: Activation *class*, instantiated once here.

    Raises:
        ValueError: If ``in_channels != out_channels``.
    """

    def __init__(
        self,
        in_channels,
        mid_channels,
        out_channels,
        kernel_size,
        activation,
    ):
        super().__init__()
        # Fail at construction rather than with a shape error inside forward.
        if in_channels != out_channels:
            raise ValueError(
                "ConvNeXTBlock adds its input to its output, so in_channels "
                f"({in_channels}) must equal out_channels ({out_channels})"
            )
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.activation = activation()
        # Integer floor division: same value as int(kernel_size / 2) for
        # non-negative ints, without the float round trip.
        self.padding = kernel_size // 2
        self.conv1 = nn.Conv2d(
            self.in_channels,
            self.mid_channels,
            kernel_size=self.kernel_size,
            padding=self.padding,
            padding_mode='reflect'
        )
        self.conv2 = nn.Conv2d(
            self.mid_channels,
            self.mid_channels,
            kernel_size=1
        )
        self.conv3 = nn.Conv2d(
            self.mid_channels,
            self.out_channels,
            kernel_size=1,
        )
        self.layer_norm = LayerNorm2D(self.mid_channels)

    def forward(self, x):
        """Return ``x`` plus the output of the conv/norm/activation branch."""
        identity = x
        out = self.conv1(x)
        out = self.layer_norm(out)
        out = self.conv2(out)
        out = self.activation(out)
        out = self.conv3(out)
        # In-place add onto the freshly computed conv3 output (not onto x).
        out += identity
        return out


class ConvNeXT(torch.nn.Module):
    """Convolutional network mapping stacked input fields to an output field.

    Architecture (built in ``__init__``):

    * a 1x1 ``ConvBlock`` lifting ``in_channels`` to ``hidden_dim``;
    * ``depth`` residual ``ConvNeXTBlock``s with ``bottleneck_dim`` inner
      channels;
    * a 1x1 ``ConvBlock`` without activation, projecting to ``out_channels``.

    The ``@torch.jit.export`` methods remain callable on the scripted module
    (e.g. through ``run_method`` from C++). They add a leading batch axis to
    raw fields and standardize each channel as ``(v - mu) / sigma``, with
    ``mu, sigma = scalers[key][0], scalers[key][1]``. The ``unscale_*``
    methods apply the inverse, ``v * sigma + mu``.

    NOTE: the ``scale_*``/``unscale_*`` methods modify their argument in
    place. ``get_parflow_pressure`` and ``get_parflow_evaptrans`` hand them
    views (``unsqueeze``/slicing) of their input, so the caller's tensor is
    modified too. Pass a copy if the original values are still needed.

    Args:
        in_channels: Total channels of the concatenated ``forward`` inputs.
        out_channels: Channels produced by ``forward``.
        bottleneck_dim: Inner width of each ``ConvNeXTBlock``.
        hidden_dim: Width of the residual stream.
        kernel_size: Kernel size of each block's first convolution.
        depth: Number of ``ConvNeXTBlock``s.
        activation: Activation class, instantiated by each block.
        scalers: Mapping from a channel key to an indexable ``(mu, sigma)``.
            Pressure channel ``i`` uses key ``f"press_diff_{i}"``.
        pressure_names: Stored on the module; not used by the methods here.
        evaptrans_names: Scaler keys for the evapotranspiration channels,
            in channel order.
        param_names: Scaler keys for the stacked static channels, in order.
        n_evaptrans: Int layer selector for ``get_parflow_evaptrans``:
            ``> 0`` keeps the first ``n`` layers of axis 0, ``< 0`` keeps the
            last ``-n``, and ``0`` keeps all. Comparisons against ``None``
            fail, so it must be set when that method is used.
        parameter_list: Keys into the ``statics`` dict passed to
            ``get_parflow_statics``, in stacking order.
        param_nlayer: Per-parameter layer selector, with the same sign
            convention as ``n_evaptrans``. Only applied to fields with more
            than one layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_dim=64,
        hidden_dim=256,
        kernel_size=5,
        depth=1,
        activation=nn.GELU,
        scalers=DEFAULT_SCALERS,
        pressure_names=None,
        evaptrans_names=None,
        param_names=None,
        n_evaptrans=None,
        parameter_list=None,
        param_nlayer=None
    ):
        super().__init__()
        self.input_channels = in_channels
        self.bottleneck_dim = bottleneck_dim
        self.hidden_dim = hidden_dim
        self.output_channels = out_channels
        self.kernel_size = kernel_size
        self.depth = depth
        self.activation = activation
        self.scalers = scalers
        self.pressure_names = pressure_names
        self.evaptrans_names = evaptrans_names
        self.n_evaptrans = n_evaptrans
        self.param_names = param_names
        self.parameter_list = parameter_list
        self.param_nlayer = param_nlayer

        # Build the stack locally and register it once. The construction order
        # (and therefore the state_dict keys layers.0, layers.1, ...) matches
        # the original, so existing checkpoints still load.
        layers = [
            ConvBlock(
                self.input_channels,
                self.hidden_dim,
                1,
                self.activation,
            )
        ]
        for _ in range(self.depth):
            layers.append(
                ConvNeXTBlock(
                    self.hidden_dim,
                    self.bottleneck_dim,
                    self.hidden_dim,
                    self.kernel_size,
                    self.activation,
                )
            )
        layers.append(
            ConvBlock(
                self.hidden_dim,
                self.output_channels,
                1,
                nn.Identity
            )
        )
        self.layers = nn.ModuleList(layers)

    @torch.jit.export
    def get_parflow_pressure(self, pressure):
        """Add a batch axis to ``pressure`` and standardize it (in place)."""
        pressure = pressure.unsqueeze(0)
        self.scale_pressure(pressure)
        return pressure

    @torch.jit.export
    def scale_pressure(self, x):
        """Standardize channel ``i`` of ``x`` in place using ``press_diff_{i}``."""
        # Dims are (batch, z, y, x)
        for i in range(x.shape[1]):
            mu = self.scalers[f'press_diff_{i}'][0]
            sigma = self.scalers[f'press_diff_{i}'][1]
            x[:, i, :, :] = (x[:, i, :, :] - mu) / sigma

    @torch.jit.export
    def unscale_pressure(self, x):
        """Invert ``scale_pressure`` on ``x`` in place."""
        # Dims are (batch, z, y, x)
        for i in range(x.shape[1]):
            mu = self.scalers[f'press_diff_{i}'][0]
            sigma = self.scalers[f'press_diff_{i}'][1]
            x[:, i, :, :] = x[:, i, :, :] * sigma + mu

    @torch.jit.export
    def get_predicted_pressure(self, x):
        """Unscale the network output (in place) and drop the batch axis.

        Only axis 0 is squeezed. A bare ``squeeze()`` would also collapse
        any other size-1 axis (e.g. a domain with a single interior row),
        and the result would then no longer line up with the (z, y, x)
        field it is written back into.
        """
        self.unscale_pressure(x)
        return x.squeeze(0)

    @torch.jit.export
    def get_parflow_evaptrans(self, evaptrans):
        """Select layers per ``n_evaptrans``, add a batch axis, and standardize."""
        # Keep the first n_evaptrans layers (the bottom ones)...
        if self.n_evaptrans > 0:
            evaptrans = evaptrans[0:self.n_evaptrans,:,:]
        # ...or the last |n_evaptrans| layers (the top ones); 0 keeps all.
        elif self.n_evaptrans < 0:
            evaptrans = evaptrans[self.n_evaptrans:,:,:]
        evaptrans = evaptrans.unsqueeze(0)
        self.scale_evaptrans(evaptrans)
        return evaptrans

    @torch.jit.export
    def scale_evaptrans(self, x):
        """Standardize channel ``i`` of ``x`` in place using ``evaptrans_names[i]``."""
        # Dims are (batch, z, y, x)
        for i, name in enumerate(self.evaptrans_names):
            mu = self.scalers[name][0]
            sigma = self.scalers[name][1]
            x[:, i, :, :] = (x[:, i, :, :] - mu) / sigma

    @torch.jit.export
    def unscale_evaptrans(self, x):
        """Invert ``scale_evaptrans`` on ``x`` in place."""
        # Dims are (batch, z, y, x)
        for i, name in enumerate(self.evaptrans_names):
            mu = self.scalers[name][0]
            sigma = self.scalers[name][1]
            x[:, i, :, :] = x[:, i, :, :] * sigma + mu

    @torch.jit.export
    def get_parflow_statics(self, statics:Dict[str, torch.Tensor]):
        """Stack the selected static fields along axis 0, add a batch axis, standardize.

        ``torch.cat`` allocates a new tensor, so unlike the pressure and
        evaptrans helpers this does not modify the tensors in ``statics``.
        """
        selected = []
        for (parameter, n_lay) in zip(self.parameter_list, self.param_nlayer):
            param_temp = statics[parameter]
            # Layer selection only applies to multi-layer fields.
            if param_temp.shape[0] > 1:
                # Grab the bottom n_lay layers
                if n_lay > 0:
                    param_temp = param_temp[0:n_lay,:,:]
                # Grab the top n_lay layers
                elif n_lay < 0:
                    param_temp = param_temp[n_lay:,:,:]
            selected.append(param_temp)

        # Concatenate the parameter data together
        # End result is a dims of (n_parameters, y, x)
        parameter_data = torch.cat(selected, dim=0)
        parameter_data = parameter_data.unsqueeze(0)
        self.scale_statics(parameter_data)
        return parameter_data

    @torch.jit.export
    def scale_statics(self, x):
        """Standardize channel ``i`` of ``x`` in place using ``param_names[i]``."""
        for i, name in enumerate(self.param_names):
            mu = self.scalers[name][0]
            sigma = self.scalers[name][1]
            x[:, i, :, :] = (x[:, i, :, :] - mu) / sigma

    @torch.jit.export
    def unscale_statics(self, x):
        """Invert ``scale_statics`` on ``x`` in place."""
        for i, name in enumerate(self.param_names):
            mu = self.scalers[name][0]
            sigma = self.scalers[name][1]
            x[:, i, :, :] = x[:, i, :, :] * sigma + mu

    def forward(self, pressure, evaptrans, statics):
        """Concatenate the inputs along axis 1 (channels) and run the layer stack.

        The inputs must agree on every axis except 1. The ``get_parflow_*``
        helpers produce tensors with a leading batch axis of size 1.

        NOTE(review): the convolutions run in the dtype of the inputs and
        weights. PyTorch's optimized (oneDNN) CPU convolution kernels do not
        appear to cover float64, so a model driven with double tensors likely
        falls back to the slow reference conv path. Confirm this, and consider
        float32 if the precision is acceptable.
        """
        x = torch.cat([pressure, evaptrans, statics], dim=1)

        for layer in self.layers:
            x = layer(x)

        return x
#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>

using namespace torch::indexing;

// Process-wide state shared by the two extern "C" entry points below.
// init_torch_model() loads the TorchScript module into `model` and stores the
// scaled static-parameter tensor in `statics`. predict_next_pressure_step()
// reads both, so init_torch_model() is expected to run first.
static torch::jit::script::Module model;
static torch::Tensor statics;

extern "C" {
  void init_torch_model(char* model_filepath, int nx, int ny, int nz, double *po_dat,
			double *mann_dat, double *slopex_dat, double *slopey_dat, double *permx_dat,
			double *permy_dat, double *permz_dat, double *sres_dat, double *ssat_dat,
			double *fbz_dat, double *specific_storage_dat, double *alpha_dat, double *n_dat,
			int torch_debug) {                           
    std::string model_path = std::string(model_filepath);
    c10::InferenceMode guard;
    try {
      model = torch::jit::load(model_path);
      model.eval();
    }
    catch (const c10::Error& e) {
      throw std::runtime_error(std::string("Failed to load the Torch model:\n") + e.what());
    }

    // Get the true fields without the ghost nodes
    std::unordered_map<std::string, torch::Tensor> statics_map;
    torch::Tensor porosity = torch::from_blob(po_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["porosity"] = porosity;
    torch::Tensor mannings = torch::from_blob(mann_dat, {3, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["mannings"] = mannings;
    torch::Tensor slope_x = torch::from_blob(slopex_dat, {3, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["slope_x"] = slope_x;
    torch::Tensor slope_y = torch::from_blob(slopey_dat, {3, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["slope_y"] = slope_y;
    torch::Tensor perm_x = torch::from_blob(permx_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["perm_x"] = perm_x;
    torch::Tensor perm_y = torch::from_blob(permy_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["perm_y"] = perm_y;
    torch::Tensor perm_z = torch::from_blob(permz_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["perm_z"] = perm_z;
    torch::Tensor sres = torch::from_blob(sres_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["sres"] = sres;
    torch::Tensor ssat = torch::from_blob(ssat_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["ssat"] = ssat;
    torch::Tensor fbz = torch::from_blob(fbz_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["pf_flowbarrier"] = fbz;
    torch::Tensor specific_storage = torch::from_blob(specific_storage_dat,
						      {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["specific_storage"] = specific_storage;
    torch::Tensor alpha = torch::from_blob(alpha_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["alpha"] = alpha;
    torch::Tensor n = torch::from_blob(n_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    statics_map["n"] = n;

    statics = model.run_method("get_parflow_statics", statics_map).toTensor();
    if (torch_debug) {
      torch::save(statics, "scaled_statics.pt");
    }
  }
  
  double* predict_next_pressure_step(double* pp, double* et, int nx, int ny, int nz, int file_number) {
    c10::InferenceMode guard;
    torch::Tensor press = torch::from_blob(pp, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    torch::Tensor evaptrans = torch::from_blob(et, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
    press = model.run_method("get_parflow_pressure", press).toTensor();
    evaptrans = model.run_method("get_parflow_evaptrans", evaptrans).toTensor();

    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(press);
    inputs.push_back(evaptrans);
    inputs.push_back(statics);
    torch::Tensor output = model.forward(inputs).toTensor();
    torch::Tensor model_output = model.run_method("get_predicted_pressure", output).toTensor();
    torch::Tensor predicted_pressure = torch::from_blob(pp, {nz, ny, nx}, torch::kDouble);
    predicted_pressure.index_put_({Slice(1, nz-1), Slice(1, ny-1), Slice(1, nx-1)}, model_output);
    
    if (!predicted_pressure.is_contiguous()) {
      predicted_pressure = predicted_pressure.contiguous();
    }

    double* predicted_pressure_array = predicted_pressure.data_ptr<double>();

    // Copy pressure data back to the pressure field
    if (predicted_pressure_array != pp) {
      std::size_t sz = nx * ny * nz;
      std::copy(predicted_pressure_array, predicted_pressure_array + sz, pp);
    }
    return pp;
  }
}

I had a similar phenomenon happening when I compiled my model using ExecuTorch. After compiling for a CPU backend, I noticed super slow inference. Are you using an optimized backend at all?

In my case, I cross-compiled XNNPACK and linked those libraries into my project. That gave me a 30x boost in inference performance (latency dropped from 1500 ms → 50 ms on a Cortex-A53).

Thanks for the reply and the suggestion!

I am not using any optimized backends. I downloaded the CPU version of libtorch from the PyTorch website and assumed that it would already be optimized for a production environment.

I did some profiling and noticed that the code calls a function “_slow_conv2d_forward()”; each inference step spends >90% of its time in that function. The name caught my attention: it turns out the code is not calling any optimized 2D-convolution kernel, for some reason (I don’t yet know why).

I will try your suggestion and report back. Any more insights/suggestions are welcome, thanks!