Performing Inference with input already in GPU

Hi all,

I have a trained model in Python but ultimately want to run inference in C++. I followed the procedure to save a serialized model with TorchScript. I then load the model in C++ and move it to the GPU.

I am working in a pipeline where all data already resides on the GPU. My question is whether I can take advantage of that and wrap the existing device memory as a PyTorch CUDA tensor, without performing any host/device copies. I believe I have found an example of something similar: https://github.com/pytorch/pytorch/issues/19786

The link references a post that shows how to convert a cv::cuda::GpuMat of type 32F to a torch::Tensor. This isn’t exactly what I would like to do, as I am not working with cv::cuda::GpuMat types, but my data resides in CUDA memory and I believe the concept is similar. Essentially, torch::from_blob is used for the conversion.

I have also found another post: Constructing PyTorch's CUDA tensor from C++ with image data already on GPU

However, that post does not seem to mention torch::from_blob. Rather, it talks about using pointer references and the cudaMemcpyAsync method. Unless I am mistaken, I believe this does perform a host/device copy.

Am I on the right track here?

Ok so I experimented a bit with this, but I got different results when using Mat vs. GpuMat.

Mat code:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <fstream>
#include <iostream>
#include <memory>
#include <iomanip>

std::string type2str(int type) {
    std::string r;

    uchar depth = type & CV_MAT_DEPTH_MASK;
    uchar chans = 1 + (type >> CV_CN_SHIFT);

    switch (depth) {
    case CV_8U:  r = "8U"; break;
    case CV_8S:  r = "8S"; break;
    case CV_16U: r = "16U"; break;
    case CV_16S: r = "16S"; break;
    case CV_32S: r = "32S"; break;
    case CV_32F: r = "32F"; break;
    case CV_64F: r = "64F"; break;
    default:     r = "User"; break;
    }

    r += "C";
    r += (chans + '0');

    return r;
}

int main(int argc, const char* argv[]) {


    // Load up GPU stuff
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Running inference on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else {
        std::cout << "Running inference on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
    torch::data::transforms::Normalize<> normalize_transform({ 0.485, 0.456, 0.406 }, { 0.229, 0.224, 0.225 });
    torch::jit::script::Module module;
    std::cout << "Attempting to load resnet model.." << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("C:\\PyTorchPictureTest\\Model\\traced_resnet_model_pretrained.pt");
        std::cout << "Successfully loaded resnet model" << std::endl;

        module.to(at::kCUDA);
        std::cout << "Moved model to gpu" << std::endl;

        // load image and transform
        cv::Mat image;
        image = cv::imread("C:\\PyTorchPictureTest\\dog.jpg", cv::IMREAD_COLOR);
        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
        cv::Mat img_float;
        image.convertTo(img_float, CV_32FC3, 1.0f / 255.0f);


        // Scale image down
        int scaledown_factor = 256;
        cv::resize(img_float, img_float, cv::Size(img_float.cols / (img_float.rows / (float)scaledown_factor), scaledown_factor), cv::INTER_NEAREST);

        // Emulate transforms.CenterCrop(224)
        cv::Rect roi;
        int new_width = 224;
        int new_height = 224;
        roi.x = img_float.size().width / 2 - new_width / 2;
        roi.width = new_width;
        roi.y = img_float.size().height / 2 - new_height / 2;
        roi.height = new_height;
        cv::Mat img_cropped = img_float(roi);

        // Convert to tensor
        auto img_tensor = torch::from_blob(img_cropped.data, { img_cropped.rows, img_cropped.cols, 3 });
        img_tensor = img_tensor.permute({ 2, 0, 1 });
        img_tensor[0][0] = img_tensor[0][0].sub(0.485).div(0.229);
        img_tensor[0][1] = img_tensor[0][1].sub(0.456).div(0.224);
        img_tensor[0][2] = img_tensor[0][2].sub(0.406).div(0.225);
        auto img_var = torch::autograd::make_variable(img_tensor, false);
        std::cout << img_tensor.sizes() << '\n';

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(img_tensor.unsqueeze_(0).to(at::kCUDA));

        // forward pass
        torch::Tensor out_tensor = module.forward(inputs).toTensor();

        // print output
        std::cout << out_tensor << '\n';

        // Write tensor to file
        std::ofstream myfile("OutputfromMatProcessing.txt");
        if (myfile.is_open())
        {
            for (int i = 0; i < out_tensor.sizes()[1] - 1; i++) {
                myfile << out_tensor[0][i].item<float_t>() << "\n";
            }
            myfile << out_tensor[0][out_tensor.sizes()[1] - 1].item<float_t>() << "\n";
            myfile.close();
        }
        else std::cout << "Unable to open file";


    }
    catch (const c10::Error & e) {
        std::cerr << "error loading or running the model: " << e.what() << "\n";
        return -1;
    }

    std::cout << "ok\n";
}

GpuMat Code:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <fstream>
#include <iostream>
#include <memory>
#include <iomanip>

std::string type2str(int type) {
    std::string r;

    uchar depth = type & CV_MAT_DEPTH_MASK;
    uchar chans = 1 + (type >> CV_CN_SHIFT);

    switch (depth) {
    case CV_8U:  r = "8U"; break;
    case CV_8S:  r = "8S"; break;
    case CV_16U: r = "16U"; break;
    case CV_16S: r = "16S"; break;
    case CV_32S: r = "32S"; break;
    case CV_32F: r = "32F"; break;
    case CV_64F: r = "64F"; break;
    default:     r = "User"; break;
    }

    r += "C";
    r += (chans + '0');

    return r;
}

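// No-op deleter: gpu_image keeps ownership of the device memory,
// so the tensor must not free it.
void deleter(void* /*arg*/) {}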

int main(int argc, const char* argv[]) {


    // Load up GPU stuff
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Running inference on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else {
        std::cout << "Running inference on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
    torch::data::transforms::Normalize<> normalize_transform({ 0.485, 0.456, 0.406 }, { 0.229, 0.224, 0.225 });
    torch::jit::script::Module module;
    std::cout << "Attempting to load resnet model.." << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("C:\\PyTorchPictureTest\\Model\\traced_resnet_model_pretrained.pt");
        std::cout << "Successfully loaded resnet model" << std::endl;

        module.to(at::kCUDA);
        std::cout << "Moved model to gpu" << std::endl;

        // load image and transform
        cv::Mat image;
        image = cv::imread("C:\\PyTorchPictureTest\\dog.jpg", cv::IMREAD_COLOR);
        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
        cv::Mat img_float;
        image.convertTo(img_float, CV_32FC3, 1.0f / 255.0f);


        // Scale image down
        int scaledown_factor = 256;
        cv::resize(img_float, img_float, cv::Size(img_float.cols / (img_float.rows / (float)scaledown_factor), scaledown_factor), cv::INTER_NEAREST);

        // Emulate transforms.CenterCrop(224)
        cv::Rect roi;
        int new_width = 224;
        int new_height = 224;
        roi.x = img_float.size().width / 2 - new_width / 2;
        roi.width = new_width;
        roi.y = img_float.size().height / 2 - new_height / 2;
        roi.height = new_height;
        cv::Mat img_cropped = img_float(roi);

        cv::cuda::GpuMat gpu_image;
        gpu_image.upload(img_cropped);

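        // GpuMat rows may be padded: step is the row pitch in bytes, so
        // convert it to a stride in float elements before building the tensor.
        std::vector<int64_t> dims = { gpu_image.rows, gpu_image.cols, 3 };
        int64_t step = static_cast<int64_t>(gpu_image.step / sizeof(float));
        std::vector<int64_t> strides = { step, 3, 1 };
        auto img_tensor = torch::from_blob(gpu_image.data, dims, strides, deleter, torch::kCUDA);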

        // Reorder from HWC to CHW, then normalize
        img_tensor = img_tensor.permute({ 2, 0, 1 });
        img_tensor[0][0] = img_tensor[0][0].sub(0.485).div(0.229);
        img_tensor[0][1] = img_tensor[0][1].sub(0.456).div(0.224);
        img_tensor[0][2] = img_tensor[0][2].sub(0.406).div(0.225);
        std::cout << img_tensor.sizes() << '\n';

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(img_tensor.unsqueeze_(0).to(at::kCUDA));

        // forward pass
        torch::Tensor out_tensor = module.forward(inputs).toTensor();

        // print output
        std::cout << out_tensor << '\n';

        // Write tensor to file
        std::ofstream myfile("OutputfromGpuMatProcessing.txt");
        if (myfile.is_open())
        {
            for (int i = 0; i < out_tensor.sizes()[1] - 1; i++) {
                myfile << out_tensor[0][i].item<float_t>() << "\n";
            }
            myfile << out_tensor[0][out_tensor.sizes()[1] - 1].item<float_t>() << "\n";
            myfile.close();
        }
        else std::cout << "Unable to open file";


    }
    catch (const c10::Error & e) {
        std::cerr << "error loading or running the model: " << e.what() << "\n";
        return -1;
    }

    std::cout << "ok\n";
}

However, I do not seem to be getting the same outputs.

For example, the first 10 outputs in the Mat version:
-0.630255
-0.562184
-0.558607
-1.53902
-0.712656
-0.198479
-0.437339
0.591292
0.491775
-0.71214

First 10 outputs in GpuMat version:
-0.598255
-0.585769
-0.610221
-1.58523
-0.771842
-0.215817
-0.574213
0.463584
0.402193
-0.723421

Am I using the from_blob function incorrectly?

@tom Do you have any insight on this? Also, when you use from_blob with source data that is already on the GPU, does the tensor just reference the existing memory, or does it make a copy on the GPU?

No, all I’m doing is layman epidemiology this week. But what happens if you cut the model out of this and compare the inputs?
Is the result random or is it somehow connected to the image?
The one time I did look at tensors already on the GPU was when I worked on DLPack-based interoperability of CuPy and PyTorch a long time ago. That seemed to work well, so one could look at what DLPack does.

Best regards

Thomas
