Hello,
I am training a model with PyTorch on an A100 GPU and then saving it with TorchScript as a scripted model:
# Script the trained model (including its @torch.jit.export helper methods)
# and save the TorchScript archive.
# NOTE(review): the archive stores each tensor on the device the model is on
# when it is scripted (presumably CUDA here); moving the model to the CPU first,
# or loading later with map_location="cpu", avoids a CUDA-resident archive.
m = torch.jit.script(model)
torch.jit.save(m, model_filename)
I then convert it to a CPU model with
# Reload the scripted model with every tensor mapped to the CPU, then re-save it.
# torch.load() on a TorchScript archive merely forwards to torch.jit.load()
# with the default map_location (tensors restored on the device they were
# saved from, i.e. the GPU), and torch.jit.script() on an already-scripted
# module returns it unchanged -- so call torch.jit.load directly and ask for
# the CPU explicitly.
model = torch.jit.load(model_filename, map_location="cpu")
model.eval()
torch.jit.save(model, cpu_model_filename)
I then load it and run inference from my C++ code. However, inference is very slow — slower than running the same model from Python on the CPU. I’m using version 2.5.1 of both PyTorch and LibTorch.
Any insights would be greatly appreciated!
For completeness, here is my model code, followed by the C++ code that calls model.forward():
import torch
import torch.nn.functional as F
from torch import nn
from typing import Dict
class LayerNorm2D(nn.LayerNorm):
    """Layer normalization over dim 1 (the channel dim) of a 4-D tensor.

    The input is permuted so that dim 1 becomes the last dim, normalized with
    ``F.layer_norm`` using this module's shape, affine parameters and eps, and
    permuted back, so the output has the same shape as the input.

    Args:
        num_channels: size of dim 1 of the inputs.
        eps: value added to the variance for numerical stability.
        affine: whether to learn a per-channel weight and bias.
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x):
        # Move the channel dim last so F.layer_norm normalizes over it.
        channels_last = x.permute(0, 2, 3, 1)
        normalized = F.layer_norm(
            channels_last,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        # Restore the original dim order.
        return normalized.permute(0, 3, 1, 2)
class ConvBlock(nn.Module):
    """One ``nn.Conv2d`` with reflect padding, followed by an activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size. Padding is
            ``int(kernel_size / 2)``, which preserves the spatial size for
            odd kernel sizes.
        activation: an ``nn.Module`` *class* (not an instance); it is called
            with no arguments to build the activation layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, activation):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        # Built before the convolution, as in the layer order of forward().
        self.activation = activation()
        self.padding = int(kernel_size / 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=self.padding,
            padding_mode='reflect',
        )

    def forward(self, x):
        convolved = self.conv(x)
        return self.activation(convolved)
class ConvNeXTBlock(nn.Module):
    """Residual block: conv -> LayerNorm2D -> 1x1 conv -> activation -> 1x1 conv.

    The block input is added in place to the output of the last convolution,
    so ``out_channels`` must be compatible with ``in_channels`` for that
    addition (normally they are equal).

    Args:
        in_channels: number of input channels.
        mid_channels: channel count inside the block.
        out_channels: number of output channels.
        kernel_size: kernel size of the first (reflect-padded) convolution;
            padding is ``int(kernel_size / 2)``.
        activation: an ``nn.Module`` *class*, instantiated with no arguments
            and applied after the middle 1x1 convolution.
    """

    def __init__(self, in_channels, mid_channels, out_channels, kernel_size, activation):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.activation = activation()
        self.padding = int(kernel_size / 2)
        # Submodules are created in the same order as the forward pass uses them
        # (layer_norm last), which also fixes the order of weight initialization.
        self.conv1 = nn.Conv2d(
            in_channels,
            mid_channels,
            kernel_size=kernel_size,
            padding=self.padding,
            padding_mode='reflect',
        )
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
        self.layer_norm = LayerNorm2D(mid_channels)

    def forward(self, x):
        hidden = self.layer_norm(self.conv1(x))
        hidden = self.activation(self.conv2(hidden))
        result = self.conv3(hidden)
        result += x  # residual connection, added in place
        return result
class ConvNeXT(torch.nn.Module):
    """Fully convolutional network plus exported (un)scaling helpers.

    Layer stack (``self.layers``, applied in order by ``forward``):
      1. ``ConvBlock`` with kernel size 1: ``in_channels`` -> ``hidden_dim``,
         followed by ``activation``;
      2. ``depth`` residual ``ConvNeXTBlock``s: ``hidden_dim`` ->
         ``bottleneck_dim`` -> ``hidden_dim``;
      3. ``ConvBlock`` with kernel size 1 and ``nn.Identity`` activation:
         ``hidden_dim`` -> ``out_channels``.

    The ``@torch.jit.export`` methods are compiled into the scripted module so
    non-Python callers can prepare inputs and post-process outputs. The
    ``scale_*`` / ``unscale_*`` methods modify their argument in place and
    return nothing.

    Args:
        in_channels: channel count of the tensor that ``forward`` builds by
            concatenating pressure, evaptrans and statics along dim 1.
        out_channels: number of output channels.
        bottleneck_dim: ``mid_channels`` of each ``ConvNeXTBlock``.
        hidden_dim: channel count between blocks.
        kernel_size: kernel size of each ``ConvNeXTBlock``'s first convolution.
        depth: number of ``ConvNeXTBlock``s.
        activation: activation module *class*, passed to the blocks.
        scalers: mapping from a variable name to an indexable whose ``[0]`` is
            subtracted and by whose ``[1]`` the values are divided when scaling
            (presumably mean and standard deviation -- confirm). Keys read:
            ``press_diff_{i}`` plus every entry of ``evaptrans_names`` and
            ``param_names``.
        pressure_names: stored; not read by any method of this class.
        evaptrans_names: scaler keys for the evaptrans channels, in channel order.
        param_names: scaler keys for the static channels, in channel order.
        n_evaptrans: layer selection used by ``get_parflow_evaptrans``
            (> 0: first n, < 0: last -n, 0: all). NOTE(review): the default
            ``None`` cannot be compared with 0, so it must be set to an int
            for that method to work.
        parameter_list: keys of the ``statics`` dict read by
            ``get_parflow_statics``, in concatenation order.
        param_nlayer: per-parameter layer selection for ``get_parflow_statics``
            (same sign convention as ``n_evaptrans``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_dim=64,
        hidden_dim=256,
        kernel_size=5,
        depth=1,
        activation=nn.GELU,
        # NOTE(review): DEFAULT_SCALERS is not defined in the code shown;
        # presumably a module-level dict defined elsewhere in the file.
        scalers=DEFAULT_SCALERS,
        pressure_names=None,
        evaptrans_names=None,
        param_names=None,
        n_evaptrans=None,
        parameter_list=None,
        param_nlayer=None
    ):
        super().__init__()
        self.input_channels = in_channels
        self.bottleneck_dim = bottleneck_dim
        self.hidden_dim = hidden_dim
        self.output_channels = out_channels
        self.kernel_size = kernel_size
        self.depth = depth
        self.activation = activation
        self.scalers = scalers
        self.pressure_names = pressure_names
        self.evaptrans_names = evaptrans_names
        self.n_evaptrans = n_evaptrans
        self.param_names = param_names
        self.parameter_list = parameter_list
        self.param_nlayer = param_nlayer
        # The layer stack is collected in a plain list first and wrapped in an
        # nn.ModuleList at the end so that its submodules get registered.
        self.layers = [
            ConvBlock(
                self.input_channels,
                self.hidden_dim,
                1,
                self.activation,
            )
        ]
        for i in range(self.depth):
            self.layers.append(
                ConvNeXTBlock(
                    self.hidden_dim,
                    self.bottleneck_dim,
                    self.hidden_dim,
                    self.kernel_size,
                    self.activation,
                )
            )
        # Output projection: no activation on the final layer.
        self.layers.append(
            ConvBlock(
                self.hidden_dim,
                self.output_channels,
                1,
                nn.Identity
            )
        )
        self.layers = nn.ModuleList(self.layers)

    @torch.jit.export
    def get_parflow_pressure(self, pressure):
        """Add a leading batch dim to ``pressure``, scale it in place and return it.

        Assumes ``pressure`` is (z, y, x) -- TODO confirm -- so the result is
        (1, z, y, x). ``unsqueeze`` returns a view, so the caller's tensor is
        scaled as well.
        """
        pressure = pressure.unsqueeze(0)
        self.scale_pressure(pressure)
        return pressure

    @torch.jit.export
    def scale_pressure(self, x):
        """Apply ``(x - mu) / sigma`` to every channel of ``x`` in place.

        Channel ``i`` uses ``self.scalers[f'press_diff_{i}']`` as (mu, sigma).
        """
        # Dims are (batch, z, y, x)
        for i in range(x.shape[1]):
            mu = self.scalers[f'press_diff_{i}'][0]
            sigma = self.scalers[f'press_diff_{i}'][1]
            x[:, i, :, :] = (x[:, i, :, :] - mu) / sigma

    @torch.jit.export
    def unscale_pressure(self, x):
        """Invert ``scale_pressure`` in place: ``x * sigma + mu`` per channel."""
        # Dims are (batch, z, y, x)
        for i in range(x.shape[1]):
            mu = self.scalers[f'press_diff_{i}'][0]
            sigma = self.scalers[f'press_diff_{i}'][1]
            x[:, i, :, :] = x[:, i, :, :] * sigma + mu

    @torch.jit.export
    def get_predicted_pressure(self, x):
        """Unscale the network output ``x`` in place and return ``x.squeeze()``.

        NOTE(review): ``squeeze()`` drops *every* size-1 dim (the batch dim,
        and also z, y or x if any of them is 1), so the rank of the result
        depends on the data; ``squeeze(0)`` would drop only the batch dim --
        confirm which is intended.
        """
        self.unscale_pressure(x)
        return x.squeeze()

    @torch.jit.export
    def get_parflow_evaptrans(self, evaptrans):
        """Select evaptrans layers, add a batch dim, scale in place and return.

        ``n_evaptrans > 0`` keeps the first ``n_evaptrans`` entries of dim 0,
        ``n_evaptrans < 0`` keeps the last ``-n_evaptrans``, and ``0`` keeps
        all. Basic slicing and ``unsqueeze`` both return views, so the scaling
        also modifies the caller's tensor.
        """
        # Keep the first n_evaptrans layers
        if self.n_evaptrans > 0:
            evaptrans = evaptrans[0:self.n_evaptrans,:,:]
        # Keep the last -n_evaptrans layers
        elif self.n_evaptrans < 0:
            evaptrans = evaptrans[self.n_evaptrans:,:,:]
        evaptrans = evaptrans.unsqueeze(0)
        self.scale_evaptrans(evaptrans)
        return evaptrans

    @torch.jit.export
    def scale_evaptrans(self, x):
        """Apply ``(x - mu) / sigma`` in place; channel ``i`` uses ``evaptrans_names[i]``."""
        # Dims are (batch, z, y, x)
        for i, name in enumerate(self.evaptrans_names):
            mu = self.scalers[name][0]
            sigma = self.scalers[name][1]
            x[:, i, :, :] = (x[:, i, :, :] - mu) / sigma

    @torch.jit.export
    def unscale_evaptrans(self, x):
        """Invert ``scale_evaptrans`` in place: ``x * sigma + mu`` per channel."""
        # Dims are (batch, z, y, x)
        for i, name in enumerate(self.evaptrans_names):
            mu = self.scalers[name][0]
            sigma = self.scalers[name][1]
            x[:, i, :, :] = x[:, i, :, :] * sigma + mu

    @torch.jit.export
    def get_parflow_statics(self, statics:Dict[str, torch.Tensor]):
        """Assemble, batch and scale the static-parameter tensor.

        For each ``(parameter, n_lay)`` in ``zip(parameter_list, param_nlayer)``,
        ``statics[parameter]`` (indexed as (layer, y, x)) is trimmed along dim 0
        when it has more than one layer: ``n_lay > 0`` keeps the first
        ``n_lay`` layers, ``n_lay < 0`` the last ``-n_lay``, ``0`` keeps all.
        The pieces are concatenated along dim 0, given a leading batch dim and
        scaled in place by ``scale_statics``. ``torch.cat`` copies, so the
        tensors in ``statics`` are not modified.

        Returns a tensor of shape (1, total_layers, y, x).
        """
        parameter_data = []
        for (parameter, n_lay) in zip(self.parameter_list, self.param_nlayer):
            param_temp = statics[parameter]
            if param_temp.shape[0] > 1:
                # Keep only the bottom or top layers if requested in param_nlayer
                # Grab the bottom n_lay layers
                if n_lay > 0:
                    param_temp = param_temp[0:n_lay,:,:]
                # Grab the top n_lay layers
                elif n_lay < 0:
                    param_temp = param_temp[n_lay:,:,:]
            parameter_data.append(param_temp)
        # Concatenate the parameter data together
        # End result is a dims of (n_parameters, y, x)
        parameter_data = torch.cat(parameter_data, dim=0)
        parameter_data = parameter_data.unsqueeze(0)
        self.scale_statics(parameter_data)
        return parameter_data

    @torch.jit.export
    def scale_statics(self, x):
        """Apply ``(x - mu) / sigma`` in place; channel ``i`` uses ``param_names[i]``.

        Presumably ``param_names`` has one entry per channel produced by
        ``get_parflow_statics`` -- verify against the configuration.
        """
        for i, name in enumerate(self.param_names):
            mu = self.scalers[name][0]
            sigma = self.scalers[name][1]
            x[:, i, :, :] = (x[:, i, :, :] - mu) / sigma

    @torch.jit.export
    def unscale_statics(self, x):
        """Invert ``scale_statics`` in place: ``x * sigma + mu`` per channel."""
        for i, name in enumerate(self.param_names):
            mu = self.scalers[name][0]
            sigma = self.scalers[name][1]
            x[:, i, :, :] = x[:, i, :, :] * sigma + mu

    def forward(self, pressure, evaptrans, statics):
        """Concatenate the three inputs along dim 1 and apply ``self.layers`` in order.

        The inputs must agree in every dim except dim 1 (the helpers above
        produce (batch, channels, y, x) tensors). The output is presumably in
        scaled units; ``get_predicted_pressure`` unscales it.
        """
        # Concatenate the data
        x = torch.cat([pressure, evaptrans, statics], dim=1)
        for l in self.layers:
            x = l(x)
        return x
#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
using namespace torch::indexing;
// TorchScript module loaded once by init_torch_model and reused by every
// predict_next_pressure_step call.
static torch::jit::script::Module model;
// Scaled static-parameter tensor produced by init_torch_model (through the
// model's exported get_parflow_statics method) and passed to forward() on
// every prediction step.
static torch::Tensor statics;
extern "C" {
void init_torch_model(char* model_filepath, int nx, int ny, int nz, double *po_dat,
double *mann_dat, double *slopex_dat, double *slopey_dat, double *permx_dat,
double *permy_dat, double *permz_dat, double *sres_dat, double *ssat_dat,
double *fbz_dat, double *specific_storage_dat, double *alpha_dat, double *n_dat,
int torch_debug) {
std::string model_path = std::string(model_filepath);
c10::InferenceMode guard;
try {
model = torch::jit::load(model_path);
model.eval();
}
catch (const c10::Error& e) {
throw std::runtime_error(std::string("Failed to load the Torch model:\n") + e.what());
}
// Get the true fields without the ghost nodes
std::unordered_map<std::string, torch::Tensor> statics_map;
torch::Tensor porosity = torch::from_blob(po_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["porosity"] = porosity;
torch::Tensor mannings = torch::from_blob(mann_dat, {3, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["mannings"] = mannings;
torch::Tensor slope_x = torch::from_blob(slopex_dat, {3, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["slope_x"] = slope_x;
torch::Tensor slope_y = torch::from_blob(slopey_dat, {3, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["slope_y"] = slope_y;
torch::Tensor perm_x = torch::from_blob(permx_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["perm_x"] = perm_x;
torch::Tensor perm_y = torch::from_blob(permy_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["perm_y"] = perm_y;
torch::Tensor perm_z = torch::from_blob(permz_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["perm_z"] = perm_z;
torch::Tensor sres = torch::from_blob(sres_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["sres"] = sres;
torch::Tensor ssat = torch::from_blob(ssat_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["ssat"] = ssat;
torch::Tensor fbz = torch::from_blob(fbz_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["pf_flowbarrier"] = fbz;
torch::Tensor specific_storage = torch::from_blob(specific_storage_dat,
{nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["specific_storage"] = specific_storage;
torch::Tensor alpha = torch::from_blob(alpha_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["alpha"] = alpha;
torch::Tensor n = torch::from_blob(n_dat, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
statics_map["n"] = n;
statics = model.run_method("get_parflow_statics", statics_map).toTensor();
if (torch_debug) {
torch::save(statics, "scaled_statics.pt");
}
}
double* predict_next_pressure_step(double* pp, double* et, int nx, int ny, int nz, int file_number) {
c10::InferenceMode guard;
torch::Tensor press = torch::from_blob(pp, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
torch::Tensor evaptrans = torch::from_blob(et, {nz, ny, nx}, torch::kDouble).index({Slice(1, -1), Slice(1, -1), Slice(1, -1)}).clone();
press = model.run_method("get_parflow_pressure", press).toTensor();
evaptrans = model.run_method("get_parflow_evaptrans", evaptrans).toTensor();
std::vector<torch::jit::IValue> inputs;
inputs.push_back(press);
inputs.push_back(evaptrans);
inputs.push_back(statics);
torch::Tensor output = model.forward(inputs).toTensor();
torch::Tensor model_output = model.run_method("get_predicted_pressure", output).toTensor();
torch::Tensor predicted_pressure = torch::from_blob(pp, {nz, ny, nx}, torch::kDouble);
predicted_pressure.index_put_({Slice(1, nz-1), Slice(1, ny-1), Slice(1, nx-1)}, model_output);
if (!predicted_pressure.is_contiguous()) {
predicted_pressure = predicted_pressure.contiguous();
}
double* predicted_pressure_array = predicted_pressure.data_ptr<double>();
// Copy pressure data back to the pressure field
if (predicted_pressure_array != pp) {
std::size_t sz = nx * ny * nz;
std::copy(predicted_pressure_array, predicted_pressure_array + sz, pp);
}
return pp;
}
}