Getting _ivalue_ INTERNAL ASSERT FAILED at "../torch/csrc/jit/api/object.h"

Hello everybody,

I am trying to load a TorchScript module, which I generated with Python, into C++.

Following the tutorial with torchvision.resnet works fine: I can load the model in C++, forward an input tensor, and get an output.
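For reference, my loading code follows the official "Loading a TorchScript Model in C++" tutorial; roughly like this (the file name and input size are just placeholders for the resnet case):

#include <torch/script.h>
#include <iostream>
#include <vector>

int main() {
  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule that was saved from Python
    module = torch::jit::load("resnet.pt");
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  // Forward a dummy input and print the output shape
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::randn({1, 3, 224, 224}));
  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << output.sizes() << std::endl;
  return 0;
}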

But loading my custom module does not work.

Maybe you have some insight into what I am missing?

I would be grateful for any help.

Error:

terminate called after throwing an instance of 'c10::Error'
  what():  _ivalue_ INTERNAL ASSERT FAILED at "../torch/csrc/jit/api/object.h":38, please report a bug to PyTorch. 
Exception raised from _ivalue at ../torch/csrc/jit/api/object.h:38 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f1a16c94d87 in /home/glatzl/miniconda3/envs/bachelor-thesis/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f1a16c45828 in /home/glatzl/miniconda3/envs/bachelor-thesis/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #2: torch::jit::Object::find_method(std::string const&) const + 0x387 (0x7f1a035159b7 in /home/glatzl/miniconda3/envs/bachelor-thesis/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x215de (0x5621662cf5de in lid_driven_cavity_2d)
frame #4: <unknown function> + 0x8bc2 (0x5621662b6bc2 in lid_driven_cavity_2d)
frame #5: <unknown function> + 0x28150 (0x7f19cba28150 in /lib/x86_64-linux-gnu/libc.so.6)
frame #6: __libc_start_main + 0x89 (0x7f19cba28209 in /lib/x86_64-linux-gnu/libc.so.6)
frame #7: <unknown function> + 0x8e05 (0x5621662b6e05 in lid_driven_cavity_2d)

Python

# Try to implement the paper: https://arxiv.org/pdf/2203.11025.pdf
# loss function: Equation (3.7)

import h5py
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
from sklearn.metrics import mean_absolute_error, mean_squared_error

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()

        # Encoder
        self.enc_conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.enc_conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.enc_conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.enc_conv5 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)

        # Bottleneck
        self.bottleneck_conv = nn.Conv2d(1024, 512, kernel_size=3, padding=1)

        # Decoder
        self.dec_conv3 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.dec_conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.dec_conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)

        # Output
        self.final_conv = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        # Encoder
        x1 = F.relu(self.enc_conv1(x))
        x2 = F.relu(self.enc_conv2(x1))
        x3 = F.relu(self.enc_conv3(x2))
        x4 = F.relu(self.enc_conv4(x3))
        x5 = F.relu(self.enc_conv5(x4))

        # Bottleneck
        x_bottleneck = F.relu(self.bottleneck_conv(x5))

        # Decoder
        x_dec3 = F.relu(self.dec_conv3(x_bottleneck))
        x_dec2 = F.relu(self.dec_conv2(x_dec3))
        x_dec1 = F.relu(self.dec_conv1(x_dec2))

        # Output
        x_out = F.relu(self.final_conv(x_dec1))

        return x_out
    
def custom_loss(output, target):

    loss = torch.mean(torch.abs(output - target) / (torch.abs(target) + 1))
    return loss

def evaluate_regression(model, test_data, criterion):
    model.eval()
    with torch.no_grad():
        outputs = model(test_data[0])
        loss = criterion(outputs, test_data[1])
        true_data = test_data[1].to('cpu').numpy().flatten()
        predicted_data = outputs.to('cpu').numpy().flatten()
        mae = mean_absolute_error(true_data, predicted_data)
        rmse = np.sqrt(mean_squared_error(true_data, predicted_data))

    return loss.item(), mae, rmse

# Load data
pressure_data = torch.tensor([])
RHS_data = torch.tensor([])

with h5py.File("pressure.h5", "r") as f:
    keys = list(f.keys())
    max_iter = 100
    for key in keys:
        if max_iter == 0:
            break
        max_iter -= 1
        pressure_data = torch.cat((pressure_data, torch.tensor(f[key]).unsqueeze(0)), 0)

with h5py.File("RHS.h5", "r") as f:
    keys = list(f.keys())
    max_iter = 100
    for key in keys:
        if max_iter == 0:
            break
        max_iter -= 1
        RHS_data = torch.cat((RHS_data, torch.tensor(f[key]).unsqueeze(0)), 0)

# Prepare data
pressure_data = pressure_data.view(pressure_data.shape[0], 1, 34, 34).float()
RHS_data = RHS_data.view(RHS_data.shape[0], 1, 34, 34).float()

# Split data into train and test sets
total_samples = pressure_data.shape[0]
train_size = int(0.8 * total_samples)
test_size = total_samples - train_size

train_data, test_data = random_split(list(zip(pressure_data, RHS_data)), [train_size, test_size])

train_pressure_data, train_RHS_data = zip(*train_data)
test_pressure_data, test_RHS_data = zip(*test_data)

train_pressure_data = torch.stack(train_pressure_data).to("cuda")
train_RHS_data = torch.stack(train_RHS_data).to("cuda")
test_pressure_data = torch.stack(test_pressure_data).to("cuda")
test_RHS_data = torch.stack(test_RHS_data).to("cuda")

# Create the model
model = SimpleCNN()
model.to("cuda")

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.000001)

# Set the model in training mode
model.train()

# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    # Forward pass
    output = model(train_RHS_data)

    # Calculate the loss
    loss = custom_loss(output, train_pressure_data)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {loss.item()}")

# Evaluate the model on the test set
test_loss, test_mae, test_rmse = evaluate_regression(model, (test_RHS_data, test_pressure_data), custom_loss)
print(f"Test Loss: {test_loss}, Test MAE: {test_mae}, Test RMSE: {test_rmse}")

# Convert to TorchScript via tracing
model.eval()
model = model.to("cpu")
test_RHS_data = test_RHS_data.to("cpu")
traced = torch.jit.trace(model, test_RHS_data[0])
traced.save("model.pt")
print(traced.code)

C++

torch::Tensor input = torch::randn({1, 1, 34, 34});

// Execute the model
at::Tensor output = this->model.forward({input}).toTensor();
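For context, model is a member of the class that drives the simulation. The intent is roughly the following; this is a simplified sketch, not my exact code, and the class and member names are just placeholders:

class PressureSolver {
 public:
  explicit PressureSolver(const std::string& model_path)
      : model(torch::jit::load(model_path)) {
    // Put the traced module into inference mode after loading
    model.eval();
  }

  torch::Tensor predict(const torch::Tensor& input) {
    // Run the traced model on a single sample
    return model.forward({input}).toTensor();
  }

 private:
  torch::jit::script::Module model;
};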

Best regards,