Values after softmax in TorchScript in C++ are not the same as in PyTorch in Python

#include <torch/torch.h>
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iomanip> // std::setprecision

using namespace cv;
using namespace std;

int main() {
    
    std::cout << std::fixed << std::setprecision(4);
    torch::jit::script::Module module = torch::jit::load("/data2/aditya/classifiers/LHSRHS/lhsrhs_script.zip");
    torch::data::transforms::Normalize<> normalize_transform({0.485, 0.456, 0.406}, {0.229, 0.224, 0.225});
    
        
    Mat image_bgr, image, image1;
    image_bgr = imread("/data2/aditya/classifiers/LHSRHS/new_data801/RHS/100_5751_rhs_front_door_assy.png");
    cvtColor(image_bgr, image, COLOR_BGR2RGB);
    resize(image, image1, {448, 448});
    
//     cout<<image1<<'\n';
    
    torch::Tensor tensor_image = torch::from_blob(image1.data, {3, image1.rows, image1.cols}, at::kByte);
    tensor_image = tensor_image.to(torch::kFloat32).div_(255);
    cout<<tensor_image.sizes()<<'\n';
    
    module.eval();
    torch::Tensor tensor_image1 = normalize_transform(tensor_image).unsqueeze_(0);

    std::vector<torch::jit::IValue> input;
    input.push_back(tensor_image1);
    
    at::Tensor output = module.forward(input).toTensor();
    
    at::Tensor output1 = torch::softmax(output, 1);
    
    std::cout<<output1<<'\n';
    
}
    

I trained a binary classifier in fastai, converted the model to TorchScript using torch.jit.trace, and now I am loading the image and the model in C++. But the predictions and the values I get after softmax in C++ are very different from, and much less accurate than, what torch gives me in Python, and I cannot figure out what the problem is. Please check where I am going wrong.

Where is the Python code?
Are you sure you are using the same transforms in both cases?


@dambo Thank you for your response.
Yes, I am sure. The inference code I am using in Python is:

learn = load_learner("/data2/aditya/classifiers/LHSRHS/new_data801/", "final_lhs_rhs_model.pkl")
my_model = learn.model.cpu()
softmaxer = torch.nn.Softmax(dim=1)
my_model.eval()
image = Image.open("/data2/aditya/classifiers/LHSRHS/new_data801/RHS/100_5751_rhs_front_door_assy.png")
image = image.resize((448, 448))
x = TTF.to_tensor(image)
x = TTF.normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
x.unsqueeze_(0)
print(x.size())
raw_out = my_model(x)
print(softmaxer(raw_out))

traced_cell = torch.jit.trace(my_model, x)
traced_cell.save('lhsrhs_script.zip')

I think there is some problem with the normalize_transform function in C++, but I don't know what it is, because when I print tensor_image1 in C++ the values are not what I expect.

If I'm not mistaken, OpenCV reads the image as [height, width, channels], while your Python pipeline (TTF.to_tensor) produces a [channels, height, width] tensor.
The torch::from_blob call might thus interleave the pixels. Have you checked the outputs?

Also, resize uses linear interpolation by default in OpenCV, while PIL uses nearest neighbor by default.

Could you check the results separately?


Try reading your image like this:
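A minimal sketch of the idea, assuming an 8-bit image and the file path from your post (the key points: pass torch::from_blob the [rows, cols, channels] shape OpenCV actually stores, then permute to [C, H, W]):

cv::Mat image_bgr = cv::imread("/data2/aditya/classifiers/LHSRHS/new_data801/RHS/100_5751_rhs_front_door_assy.png");
cv::Mat image, image1;
cv::cvtColor(image_bgr, image, cv::COLOR_BGR2RGB);
cv::resize(image, image1, cv::Size(448, 448));
// describe the memory as OpenCV stores it: [H, W, C]
torch::Tensor t = torch::from_blob(image1.data, {image1.rows, image1.cols, 3}, torch::kByte);
// then permute to the [C, H, W] layout the model expects and scale to [0, 1]
t = t.permute({2, 0, 1}).to(torch::kFloat32).div_(255);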

For the sake of testing, DO NOT use normalisation in either Python or C++ and see if you get the same results.


@ptrblck @dambo Thanks for your answers. The difference in values after softmax was due to an error in image loading. Now I am getting approximately the same values after softmax in both Python and C++.

@Aditya_Kumar What was the specific source of the problem? Were you able to get the same output with normalization? I am having trouble getting the same output and am trying to find the source of the problem.

@solarflarefx There was a problem with image loading: I was loading the image with OpenCV, which stores it as [height, width, channels], but I was telling torch::from_blob the data was [channels, height, width]. After correcting this I got approximately the same result. Yes, I was using normalization and still getting the same answer; without normalization I couldn't get it to match. If you post your code, I think I can help you resolve your query.
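For reference, the corrected forward pass looks roughly like this (a sketch of the fix; normalize_transform is the same torch::data::transforms::Normalize<> as in my first post):

torch::Tensor tensor_image = torch::from_blob(image1.data, {image1.rows, image1.cols, 3}, torch::kByte);
tensor_image = tensor_image.permute({2, 0, 1}).to(torch::kFloat32).div_(255);
torch::Tensor tensor_image1 = normalize_transform(tensor_image).unsqueeze_(0);
std::vector<torch::jit::IValue> input{tensor_image1};
at::Tensor output = module.forward(input).toTensor();
std::cout << torch::softmax(output, 1) << '\n';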

@Aditya_Kumar I see. I thought I had accounted for this in my code, but I am putting it below just to make sure. I am simply loading the ResNet18 model and a sample image: https://raw.githubusercontent.com/pytorch/hub/master/dog.jpg

Python Code:

import torch
import torchvision.models as models
import urllib.request
from PIL import Image
from torchvision import transforms

def norm_chan(chan, mean, std):
    b = (chan - mean) / std
    return b

# Load resnet18 model
model = models.resnet18(pretrained = True)

# Download an example image from the pytorch website
url, filename = ("https://github.com/pytorch/hub/raw/master/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

# sample execution (requires torchvision)
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(input_image)

import numpy as np
input_tensor_numpy = input_tensor.numpy()
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model


# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)

print()
output_cpu = output[0].cpu()
output_numpy = output_cpu.numpy()
print(output_numpy[0:9])

C++ Code:

#include <torch/script.h>
#include <torch/torch.h>
#include <ATen/Tensor.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {


    // Load up GPU stuff
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Training on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else {
        std::cout << "Training on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);

    torch::jit::script::Module module;
    std::cout << "Attempting to load resnet model.." << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("C:\\PyTorchPictureTest\\Model\\traced_resnet_model.pt");
        std::cout << "Successfully loaded resnet model" << std::endl;

        module.to(at::kCUDA);
        std::cout << "Moved model to gpu" << std::endl;

        // load image and transform
        cv::Mat image;
        image = cv::imread("C:\\PyTorchPictureTest\\dog.jpg", 1);

        cv::cvtColor(image, image, cv::COLOR_BGR2RGB); 
        cv::Mat img_float;
        image.convertTo(img_float, CV_32F, 1.0 / 255);
        // note: interpolation is the 6th argument of cv::resize (the 4th and 5th are fx, fy)
        cv::resize(img_float, img_float, cv::Size(224, 224), 0, 0, cv::INTER_NEAREST);

        auto img_tensor = torch::from_blob(img_float.data, { 1, 224, 224, 3 }).to(torch::kCUDA);
        img_tensor = img_tensor.permute({ 0, 3, 1, 2 });
        img_tensor[0][0] = img_tensor[0][0].sub(0.485).div(0.229);
        img_tensor[0][1] = img_tensor[0][1].sub(0.456).div(0.224);
        img_tensor[0][2] = img_tensor[0][2].sub(0.406).div(0.225);
        auto img_var = torch::autograd::make_variable(img_tensor, false);

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(img_var);
        torch::Tensor out_tensor = module.forward(inputs).toTensor();
        std::cout << out_tensor.slice(1, 0, 10) << '\n';
    }
    catch (const c10::Error & e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "ok\n";
}

Is there a good way to probe intermediate points in the code? It’s easy to print tensor slices in Python. How would I do this in C++? Would it be using the narrow method? Maybe I could do it on my img_tensor in C++:

auto img_tensor_cpu = img_tensor.to(torch::kCPU);
auto img_tensor_cpu_sliced = img_tensor_cpu.narrow(2, 0, 1);
std::cout << '\n';
std::cout << img_tensor_cpu_sliced << '\n';

I believe the equivalent in my Python code would be:

print(torch.narrow(input_tensor, 2, 0, 1))

I don't seem to be getting the same values here.
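A sketch of an alternative cheap check would be per-channel means, since channel interleaving changes them even though the overall mean stays the same:

// C++ (img_tensor is [1, 3, 224, 224] at this point)
std::cout << img_tensor.mean({2, 3}) << '\n';
// Python equivalent on the 3-D tensor: print(input_tensor.mean(dim=(1, 2)))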

I think you are going wrong in this line: auto img_tensor = torch::from_blob(img_float.data, { 1, 224, 224, 3 }).to(torch::kCUDA);. Before this line the dimensions of your image are still [height, width, channels], so how did you add the extra batch dimension in { 1, 224, 224, 3 }? First load your image with only { 224, 224, 3 }, and then, just before passing it to the model, call unsqueeze_(0). For further reference, you can see my code here.

You can also torchscript the transforms themselves and call them from C++.

@Aditya_Kumar
Thanks for your reply. So I tried following your code structure:

#include <torch/script.h>
#include <torch/torch.h>
#include <ATen/Tensor.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iostream>
#include <memory>
#include <iomanip>

std::string type2str(int type) {
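    // Decodes cv::Mat::type() into a readable string such as "32FC3" (depth + channel count).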
    std::string r;

    uchar depth = type & CV_MAT_DEPTH_MASK;
    uchar chans = 1 + (type >> CV_CN_SHIFT);

    switch (depth) {
    case CV_8U:  r = "8U"; break;
    case CV_8S:  r = "8S"; break;
    case CV_16U: r = "16U"; break;
    case CV_16S: r = "16S"; break;
    case CV_32S: r = "32S"; break;
    case CV_32F: r = "32F"; break;
    case CV_64F: r = "64F"; break;
    default:     r = "User"; break;
    }

    r += "C";
    r += (chans + '0');

    return r;
}

int main(int argc, const char* argv[]) {


    // Load up GPU stuff
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Training on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else {
        std::cout << "Training on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
    torch::data::transforms::Normalize<> normalize_transform({ 0.485, 0.456, 0.406 }, { 0.229, 0.224, 0.225 });
    torch::jit::script::Module module;
    std::cout << "Attempting to load resnet model.." << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("C:\\PyTorchPictureTest\\Model\\traced_resnet_model.pt");
        std::cout << "Successfully loaded resnet model" << std::endl;

        module.to(at::kCUDA);
        std::cout << "Moved model to gpu" << std::endl;

        // load image and transform
        cv::Mat image;
        image = cv::imread("C:\\PyTorchPictureTest\\dog.jpg", cv::IMREAD_COLOR);

        std::cout << image.size << "\n" << std::endl;
        std::cout << image.cols << "\n" << std::endl;
        std::cout << image.rows << "\n" << std::endl;
        std::cout << type2str(image.type()) << "\n" << std::endl;

        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
        cv::Mat img_float;
        image.convertTo(img_float, CV_32FC3, 1.0f / 255.0f);
        cv::resize(img_float, img_float, cv::Size(224, 224), 0, 0, cv::INTER_NEAREST);

        auto img_tensor = torch::from_blob(img_float.data, { img_float.rows, img_float.cols, 3 }).to(torch::kCUDA);
        img_tensor = img_tensor.permute({ 2, 0, 1 });
        std::cout << img_tensor.sizes() << '\n';
        
        torch::Tensor img_tensor_norm = normalize_transform(img_tensor).unsqueeze_(0);

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(img_tensor_norm.to(at::kCUDA));
        torch::Tensor out_tensor = module.forward(inputs).toTensor();
        std::cout << out_tensor << '\n';
    }
    catch (const c10::Error & e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "ok\n";
}

I seem to be getting an error on this line:

torch::Tensor img_tensor_norm = normalize_transform(img_tensor).unsqueeze_(0);

Error:
<Information not available, no symbols loaded for c10.dll>
“expected device cuda:0 but got device cpu”

Sorry, my fault. The problem line was:

auto img_tensor = torch::from_blob(img_float.data, { img_float.rows, img_float.cols, 3 }).to(torch::kCUDA);

I changed this to:

auto img_tensor = torch::from_blob(img_float.data, { img_float.rows, img_float.cols, 3 });

This prevented the error from happening (normalize_transform holds its mean/std as CPU tensors, so the input has to stay on the CPU until after normalization). However, I still do not get the same output.

Python output: (screenshot)

C++ output: (screenshot)

I suppose at this point it could be that my input tensors are not the same, or that I made an error exporting to TorchScript (which I doubt, since I simply followed the example in the official documentation).

Ok, so I believe I found my source of error: in Python, before I serialized and saved the model, I had forgotten to load the pretrained weights.

Updated Python Code:

import torch
import torchvision.models as models
import urllib.request
from PIL import Image
from torchvision import transforms
import numpy as np

def norm_chan(chan, mean, std):
    b = (chan - mean) / std
    return b

# Load resnet18 model
model = models.resnet18(pretrained = True)

serialize_model = True
if (serialize_model):
    example = torch.rand(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save("traced_resnet_model_pretrained.pt")

# sample execution (requires torchvision)
input_image = Image.open("C:\\PyTorchPictureTest\\dog.jpg")
np_im = np.array(input_image)

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)

print()
output_cpu = output[0].cpu()
output_numpy = output_cpu.numpy()
for i in range(0, len(output_numpy)):
    if ( i%10 == 0):
        print("\n")
    print(str(output_numpy[i]), end = " ")
print()

Cpp Code:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iostream>
#include <memory>
#include <iomanip>

std::string type2str(int type) {
    std::string r;

    uchar depth = type & CV_MAT_DEPTH_MASK;
    uchar chans = 1 + (type >> CV_CN_SHIFT);

    switch (depth) {
    case CV_8U:  r = "8U"; break;
    case CV_8S:  r = "8S"; break;
    case CV_16U: r = "16U"; break;
    case CV_16S: r = "16S"; break;
    case CV_32S: r = "32S"; break;
    case CV_32F: r = "32F"; break;
    case CV_64F: r = "64F"; break;
    default:     r = "User"; break;
    }

    r += "C";
    r += (chans + '0');

    return r;
}

int main(int argc, const char* argv[]) {


    // Load up GPU stuff
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Training on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else {
        std::cout << "Training on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
    torch::data::transforms::Normalize<> normalize_transform({ 0.485, 0.456, 0.406 }, { 0.229, 0.224, 0.225 });
    torch::jit::script::Module module;
    std::cout << "Attempting to load resnet model.." << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("C:\\PyTorchPictureTest\\Model\\traced_resnet_model_pretrained.pt");
        std::cout << "Successfully loaded resnet model" << std::endl;

        module.to(at::kCUDA);
        std::cout << "Moved model to gpu" << std::endl;

        // load image and transform
        cv::Mat image;
        image = cv::imread("C:\\PyTorchPictureTest\\dog.jpg", cv::IMREAD_COLOR);

        std::cout << image.size << "\n" << std::endl;
        std::cout << image.cols << "\n" << std::endl;
        std::cout << image.rows << "\n" << std::endl;
        std::cout << type2str(image.type()) << "\n" << std::endl;

        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
        cv::Mat img_float;
        image.convertTo(img_float, CV_32FC3, 1.0f / 255.0f);
        cv::resize(img_float, img_float, cv::Size(224, 224), 0, 0, cv::INTER_NEAREST);

        auto img_tensor = torch::from_blob(img_float.data, { img_float.rows, img_float.cols, 3 });
        img_tensor = img_tensor.permute({ 2, 0, 1 });
        std::cout << img_tensor.sizes() << '\n';
        
        torch::Tensor img_tensor_norm = normalize_transform(img_tensor).unsqueeze_(0);

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(img_tensor_norm.to(at::kCUDA));
        torch::Tensor out_tensor = module.forward(inputs).toTensor();
        
        std::cout << out_tensor << '\n';

    }
    catch (const c10::Error & e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "ok\n";
}

Python output: (screenshot)

C++ output: (screenshot)

So the outputs are similar, but you can see the differences. Is this what you would expect? The first value in particular seems quite off (-0.7060 in C++ vs. -0.575 in Python). What is the source of the difference? Does it lie in the serialization of the model, or does it have to do with Pillow versus OpenCV?

By the way, is there an easy way to compare corresponding slices of the input tensor in C++ and Python?
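For instance, something like this sketch, assuming the normalized [1, 3, 224, 224] input tensor on both sides:

// C++: first five values of the first row of the first channel
std::cout << img_tensor_norm[0][0][0].slice(0, 0, 5) << '\n';
// Python equivalent: print(input_batch[0, 0, 0, :5])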

I think, first, you should be making predictions in evaluation mode, i.e. model.eval(), which deactivates batchnorm and dropout layers, in both C++ and Python. Also, in your Python code you are using the transforms.CenterCrop transform, which you are not doing in C++. With that, there should not be a large difference between C++ and Python: for example, if C++ gives 0.75, Python should be in the range 0.74-0.76.
Why do you want to compare the input tensors?
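On the C++ side that would look something like this minimal sketch (torch::NoGradGuard is libtorch's counterpart of Python's torch.no_grad()):

module.eval();              // put dropout/batchnorm layers into inference behaviour
torch::NoGradGuard no_grad; // disable autograd for the forward pass
torch::Tensor out_tensor = module.forward(inputs).toTensor();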

@Aditya_Kumar
Thanks for your reply. Would evaluation mode cause a difference in the output, or does it only affect speed and efficiency?

So you are right about the cropping. Since I could not find a built-in C++ equivalent of transforms.CenterCrop(224), I attempted to implement the equivalent in OpenCV.

Python code (same as before; I haven't implemented your eval-mode suggestion yet):

import torch
import torchvision.models as models
import urllib.request
from PIL import Image
from torchvision import transforms
import numpy as np

def norm_chan(chan, mean, std):
    b = (chan - mean) / std
    return b

# Load resnet18 model
model = models.resnet18(pretrained = True)

serialize_model = True
if (serialize_model):
    example = torch.rand(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save("traced_resnet_model_pretrained.pt")

# sample execution (requires torchvision)
input_image = Image.open("C:\\PyTorchPictureTest\\dog.jpg")
np_im = np.array(input_image)

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)

print()
output_cpu = output[0].cpu()
output_numpy = output_cpu.numpy()
for i in range(0, len(output_numpy)):
    if ( i%10 == 0):
        print("\n")
    print(str(output_numpy[i]), end = " ")
print()

C++ code:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iostream>
#include <memory>
#include <iomanip>

std::string type2str(int type) {
    std::string r;

    uchar depth = type & CV_MAT_DEPTH_MASK;
    uchar chans = 1 + (type >> CV_CN_SHIFT);

    switch (depth) {
    case CV_8U:  r = "8U"; break;
    case CV_8S:  r = "8S"; break;
    case CV_16U: r = "16U"; break;
    case CV_16S: r = "16S"; break;
    case CV_32S: r = "32S"; break;
    case CV_32F: r = "32F"; break;
    case CV_64F: r = "64F"; break;
    default:     r = "User"; break;
    }

    r += "C";
    r += (chans + '0');

    return r;
}

int main(int argc, const char* argv[]) {


    // Load up GPU stuff
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Training on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else {
        std::cout << "Training on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
    torch::data::transforms::Normalize<> normalize_transform({ 0.485, 0.456, 0.406 }, { 0.229, 0.224, 0.225 });
    torch::jit::script::Module module;
    std::cout << "Attempting to load resnet model.." << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("C:\\PyTorchPictureTest\\Model\\traced_resnet_model_pretrained.pt");
        std::cout << "Successfully loaded resnet model" << std::endl;

        module.to(at::kCUDA);
        std::cout << "Moved model to gpu" << std::endl;

        // load image and transform
        cv::Mat image;
        image = cv::imread("C:\\PyTorchPictureTest\\dog.jpg", cv::IMREAD_COLOR);
        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
        cv::Mat img_float;
        image.convertTo(img_float, CV_32FC3, 1.0f / 255.0f);

        // Scale image down
        int scaledown_factor = 256;
        // resize so the shorter side (assumed to be rows here) becomes 256, like transforms.Resize(256);
        // interpolation must be the 6th argument of cv::resize (the 4th and 5th are fx, fy)
        cv::resize(img_float, img_float, cv::Size(img_float.cols / (img_float.rows / (float)scaledown_factor), scaledown_factor), 0, 0, cv::INTER_NEAREST);

        // Emulate transforms.CenterCrop(224)
        cv::Rect roi;
        int new_width = 224;
        int new_height = 224;
        roi.x = img_float.size().width / 2 - new_width / 2;
        roi.width = new_width;
        roi.y = img_float.size().height / 2 - new_height / 2;
        roi.height = new_height;
        cv::Mat img_cropped = img_float(roi);

        // Convert to tensor
        auto img_tensor = torch::from_blob(img_cropped.data, { img_cropped.rows, img_cropped.cols, 3 });
        img_tensor = img_tensor.permute({ 2, 0, 1 });
        std::cout << img_tensor.sizes() << '\n';
        
        // normalize
        torch::Tensor img_tensor_norm = normalize_transform(img_tensor).unsqueeze_(0);

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(img_tensor_norm.to(at::kCUDA));

        // forward pass
        torch::Tensor out_tensor = module.forward(inputs).toTensor();
        
        // print output
        std::cout << out_tensor << '\n';

    }
    catch (const c10::Error & e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "ok\n";
}

However, even with these changes, the difference between my outputs is still larger than you suggested it should be.

Python output: (screenshot)

C++ output: (screenshot)

Regarding the comparison of input tensors: I mainly wanted to debug the cause of the difference in my outputs, i.e. whether the data going into the model differs between the two sides, or whether the model itself produces a different output.

Perhaps I am overcomplicating things for myself. I am relatively new to TorchScript and was experimenting with it to make sure that the output in C++ is equivalent to the output in Python. I chose ResNet18 since that is the example shown in PyTorch's documentation. Would I be better off doing this exercise with a different network?

Hi,

I ran into the same problem! Did you find any solution, please?