Performing inference with input already on the GPU

OK, so I experimented a bit with this, but I get different results when using `cv::Mat` vs. `cv::cuda::GpuMat`.

Mat code:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iostream>
#include <memory>
#include <iomanip>

std::string type2str(int type) {
    std::string r;

    uchar depth = type & CV_MAT_DEPTH_MASK;
    uchar chans = 1 + (type >> CV_CN_SHIFT);

    switch (depth) {
    case CV_8U:  r = "8U"; break;
    case CV_8S:  r = "8S"; break;
    case CV_16U: r = "16U"; break;
    case CV_16S: r = "16S"; break;
    case CV_32S: r = "32S"; break;
    case CV_32F: r = "32F"; break;
    case CV_64F: r = "64F"; break;
    default:     r = "User"; break;
    }

    r += "C";
    r += (chans + '0');

    return r;
}

int main(int argc, const char* argv[]) {


    // Load up GPU stuff
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Training on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else {
        std::cout << "Training on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
    torch::data::transforms::Normalize<> normalize_transform({ 0.485, 0.456, 0.406 }, { 0.229, 0.224, 0.225 });
    torch::jit::script::Module module;
    std::cout << "Attempting to load resnet model.." << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("C:\\PyTorchPictureTest\\Model\\traced_resnet_model_pretrained.pt");
        std::cout << "Successfully loaded resnet model" << std::endl;

        module.to(at::kCUDA);
        std::cout << "Moved model to gpu" << std::endl;

        // load image and transform
        cv::Mat image;
        image = cv::imread("C:\\PyTorchPictureTest\\dog.jpg", cv::IMREAD_COLOR);
        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
        cv::Mat img_float;
        image.convertTo(img_float, CV_32FC3, 1.0f / 255.0f);


        // Scale image down
        int scaledown_factor = 256;
        cv::resize(img_float, img_float, cv::Size(img_float.cols / (img_float.rows / (float)scaledown_factor), scaledown_factor), cv::INTER_NEAREST);

        // Emulate transforms.CenterCrop(224)
        cv::Rect roi;
        int new_width = 224;
        int new_height = 224;
        roi.x = img_float.size().width / 2 - new_width / 2;
        roi.width = new_width;
        roi.y = img_float.size().height / 2 - new_height / 2;
        roi.height = new_height;
        cv::Mat img_cropped = img_float(roi);

        // Convert to tensor
        auto img_tensor = torch::from_blob(img_cropped.data, { img_cropped.rows, img_cropped.cols, 3 });
        img_tensor = img_tensor.permute({ 2, 0, 1 });
        img_tensor[0][0] = img_tensor[0][0].sub(0.485).div(0.229);
        img_tensor[0][1] = img_tensor[0][1].sub(0.456).div(0.224);
        img_tensor[0][2] = img_tensor[0][2].sub(0.406).div(0.225);
        auto img_var = torch::autograd::make_variable(img_tensor, false);
        std::cout << img_tensor.sizes() << '\n';

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(img_tensor.unsqueeze_(0).to(at::kCUDA));

        // forward pass
        torch::Tensor out_tensor = module.forward(inputs).toTensor();

        // print output
        std::cout << out_tensor << '\n';

        // Write tensor to file
        std::ofstream myfile("OutputfromMatProcessing.txt");
        if (myfile.is_open())
        {
            for (int i = 0; i < out_tensor.sizes()[1] - 1; i++) {
                myfile << out_tensor[0][i].item<float_t>() << "\n";
            }
            myfile << out_tensor[0][out_tensor.sizes()[1] - 1].item<float_t>() << "\n";
            myfile.close();
        }
        else std::cout << "Unable to open file";


    }
    catch (const c10::Error & e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "ok\n";
}

GpuMat code:

#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <iostream>
#include <memory>
#include <iomanip>

std::string type2str(int type) {
    std::string r;

    uchar depth = type & CV_MAT_DEPTH_MASK;
    uchar chans = 1 + (type >> CV_CN_SHIFT);

    switch (depth) {
    case CV_8U:  r = "8U"; break;
    case CV_8S:  r = "8S"; break;
    case CV_16U: r = "16U"; break;
    case CV_16S: r = "16S"; break;
    case CV_32S: r = "32S"; break;
    case CV_32F: r = "32F"; break;
    case CV_64F: r = "64F"; break;
    default:     r = "User"; break;
    }

    r += "C";
    r += (chans + '0');

    return r;
}

// Intentionally does nothing. main passes it to torch::from_blob together with
// gpu_image.data, so the tensor never frees a buffer it does not own; the
// cv::cuda::GpuMat remains responsible for that memory.
void deleter(void* /*arg*/) {}

int main(int argc, const char* argv[]) {


    // Load up GPU stuff
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Training on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else {
        std::cout << "Training on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
    torch::data::transforms::Normalize<> normalize_transform({ 0.485, 0.456, 0.406 }, { 0.229, 0.224, 0.225 });
    torch::jit::script::Module module;
    std::cout << "Attempting to load resnet model.." << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("C:\\PyTorchPictureTest\\Model\\traced_resnet_model_pretrained.pt");
        std::cout << "Successfully loaded resnet model" << std::endl;

        module.to(at::kCUDA);
        std::cout << "Moved model to gpu" << std::endl;

        // load image and transform
        cv::Mat image;
        image = cv::imread("C:\\PyTorchPictureTest\\dog.jpg", cv::IMREAD_COLOR);
        cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
        cv::Mat img_float;
        image.convertTo(img_float, CV_32FC3, 1.0f / 255.0f);


        // Scale image down
        int scaledown_factor = 256;
        cv::resize(img_float, img_float, cv::Size(img_float.cols / (img_float.rows / (float)scaledown_factor), scaledown_factor), cv::INTER_NEAREST);

        // Emulate transforms.CenterCrop(224)
        cv::Rect roi;
        int new_width = 224;
        int new_height = 224;
        roi.x = img_float.size().width / 2 - new_width / 2;
        roi.width = new_width;
        roi.y = img_float.size().height / 2 - new_height / 2;
        roi.height = new_height;
        cv::Mat img_cropped = img_float(roi);

        cv::cuda::GpuMat gpu_image;
        gpu_image.upload(img_cropped);

        std::vector<int64_t> dims = { gpu_image.rows, gpu_image.cols, 3 };
        long long step = gpu_image.step / sizeof(float);
        std::vector<int64_t> strides = { step, 3, 1 };
        auto img_tensor = torch::from_blob(gpu_image.data, dims, strides, deleter, torch::kCUDA);

        // Convert to tensor;
        img_tensor = img_tensor.permute({ 2, 0, 1 });
        img_tensor[0][0] = img_tensor[0][0].sub(0.485).div(0.229);
        img_tensor[0][1] = img_tensor[0][1].sub(0.456).div(0.224);
        img_tensor[0][2] = img_tensor[0][2].sub(0.406).div(0.225);
        std::cout << img_tensor.sizes() << '\n';

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(img_tensor.unsqueeze_(0).to(at::kCUDA));

        // forward pass
        torch::Tensor out_tensor = module.forward(inputs).toTensor();

        // print output
        std::cout << out_tensor << '\n';

        // Write tensor to file
        std::ofstream myfile("OutputfromGpuMatProcessing.txt");
        if (myfile.is_open())
        {
            for (int i = 0; i < out_tensor.sizes()[1] - 1; i++) {
                myfile << out_tensor[0][i].item<float_t>() << "\n";
            }
            myfile << out_tensor[0][out_tensor.sizes()[1] - 1].item<float_t>() << "\n";
            myfile.close();
        }
        else std::cout << "Unable to open file";


    }
    catch (const c10::Error & e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "ok\n";
}

However, the two versions do not produce the same outputs.

For example, the first 10 outputs of the Mat version:
-0.630255
-0.562184
-0.558607
-1.53902
-0.712656
-0.198479
-0.437339
0.591292
0.491775
-0.71214

The first 10 outputs of the GpuMat version:
-0.598255
-0.585769
-0.610221
-1.58523
-0.771842
-0.215817
-0.574213
0.463584
0.402193
-0.723421

Am I using the `from_blob` function incorrectly?