ImageNet classification example in libtorch C++ got cuda argument error


(Tian Jin) #1

I am trying to convert a model trained in Python into C++ using trace_model, but I got a weird error. The error is:

Will load from ../resnet50_typescripts.pt
Model load ok.
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: invalid argument (getDeviceFromPtr at /pytorch/aten/src/ATen/cuda/CUDADevice.h:12)
frame #0: std::function<std::string ()>::operator()() const + 0x11 (0x7f2ec76bdfe1 in /media/jintain/sg/ai/tools/tfboys/pt_codes/pt_cpp/libtorch/lib/libc10.so)
frame #1: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x2a (0x7f2ec76bddfa in /media/jintain/sg/ai/tools/tfboys/pt_codes/pt_cpp/libtorch/lib/libc10.so)
frame #2: at::CUDATypeDefault::getDeviceFromPtr(void*) const + 0xb1 (0x7f2ec84ed9a1 in /media/jintain/sg/ai/tools/tfboys/pt_codes/pt_cpp/libtorch/lib/libcaffe2_gpu.so)
frame #3: at::TypeDefault::storageFromBlob(void*, long, std::function<void (void*)> const&) const + 0x33 (0x7f2ef6ce51d3 in /media/jintain/sg/ai/tools/tfboys/pt_codes/pt_cpp/libtorch/lib/libcaffe2.so)
frame #4: at::TypeDefault::tensorFromBlob(void*, c10::ArrayRef<long>, c10::ArrayRef<long>, std::function<void (void*)> const&) const + 0x85 (0x7f2ef6d32e95 in /media/jintain/sg/ai/tools/tfboys/pt_codes/pt_cpp/libtorch/lib/libcaffe2.so)
frame #5: at::TypeDefault::tensorFromBlob(void*, c10::ArrayRef<long>, std::function<void (void*)> const&) const + 0x69 (0x7f2ef6cc0ee9 in /media/jintain/sg/ai/tools/tfboys/pt_codes/pt_cpp/libtorch/lib/libcaffe2.so)
frame #6: <unknown function> + 0xb338 (0x565079d4b338 in ./ptcpp)
frame #7: __libc_start_main + 0xe7 (0x7f2ec6d3ab97 in /lib/x86_64-linux-gnu/libc.so.6)
frame #8: <unknown function> + 0xa12a (0x565079d4a12a in ./ptcpp)

[1]    4494 abort (core dumped)  ./ptcpp ../resnet50_typescripts.pt ../dog.png ../synset_words.txt

What’s more, here are my simple main.cpp and trace_model.py, which I started from the official documentation.

//
// Created by jintain on 11/20/18.
//
/**
 * A simple PyTorch classifier inference test.
 *
 * This code loads a model that was trained in Python and converted
 * into the form the C++ API needs, then runs inference with it,
 * just to go through the whole process of deploying a PyTorch
 * model in a C++ production environment.
 */



#include "torch/script.h"
#include <torch/script.h>
#include <torch/torch.h>
//#include <torch/Tensor.h>
#include <ATen/Tensor.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>


#include <iostream>
#include <memory>
//#include <vec>


using namespace std;


/**
 * Reads a label file, one label per line, and appends each line to `labels`.
 *
 * @param label_f path of the text file to read.
 * @param labels  output vector; taken by reference so the caller actually
 *                receives the lines (a by-value vector would be filled and
 *                then discarded on return).
 *
 * If the file cannot be opened a warning is written to stderr and `labels`
 * is left unchanged.
 */
void load_labels(const std::string& label_f, std::vector<std::string>& labels) {
    std::ifstream ins(label_f);
    if (!ins) {
        std::cerr << "cannot open label file " << label_f << '\n';
        return;
    }
    std::string line;
    while (std::getline(ins, line)) {
        labels.push_back(line);
    }
}


int main(int argc, const char* argv[]) {

    if (argc != 4) {
        cout << "ptcpp path/to/scripts/model.pt path/to/image.jpg path/to/label.txt\n";
        return -1;
    }

    cout << "Will load from " << argv[1] << endl;
    shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

    if (module == nullptr) {
        cerr << "model load error from " << argv[1] << endl;
    }
    cout << "Model load ok.\n";

    // load image and transform
    cv::Mat image;
    image = cv::imread(argv[2], 1);
    cv::Mat image_resized;
    cv::resize(image, image_resized, cv::Size(224, 224));
    cv::Mat image_resized_float;
    image_resized.convertTo(image_resized_float, CV_32F, 1.0/255);

    auto img_tensor = torch::CUDA(torch::kFloat32).tensorFromBlob(image_resized_float.data, {1, 224, 224, 3});
    cout << "img tensor loaded..\n";
    img_tensor = img_tensor.permute({0, 3, 1, 2});
    img_tensor[0][0] = img_tensor[0][0].sub(0.485).div(0.229);
    img_tensor[0][1] = img_tensor[0][1].sub(0.456).div(0.224);
    img_tensor[0][2] = img_tensor[0][2].sub(0.406).div(0.225);
    auto img_var = torch::autograd::make_variable(img_tensor, false);

    vector<torch::jit::IValue> inputs;
    inputs.push_back(img_var);
    torch::Tensor out_tensor = module->forward(inputs).toTensor();
    cout << out_tensor.slice(1, 0, 10) << '\n';

    // load label
    vector<string> labels;
    load_labels(argv[3], labels);
    cout << "Found all " << labels.size() << " labels.\n";

    // out tensor sort, print the first 2 category
    std::tuple<torch::Tensor,torch::Tensor> result = out_tensor.sort(-1, true);
    torch::Tensor top_scores = std::get<0>(result)[0];
    torch::Tensor top_idxs = std::get<1>(result)[0].toType(torch::kInt32);

    auto top_scores_a = top_scores.accessor<float,1>();
    auto top_idxs_a = top_idxs.accessor<int,1>();

    for (int i = 0; i < 5; ++i) {
        int idx = top_idxs_a[i];
        std::cout << "top-" << i+1 << " label: ";
        std::cout << labels[idx] << ", score: " << top_scores_a[i] << std::endl;
    }

    cv::imshow("image", image);
    cv::waitKey(0);

    return 0;
}

By the way, I posted this issue to the PyTorch GitHub repo but no one responded, so I need the community's help to debug this…

Or, could anyone provide me with an example program that actually runs…

The model is too big to attach, but it is just the typical ResNet-50 image classification model trained with PyTorch, which can be downloaded from the official site.


#2

After loading the model, convert to CUDA (like this):
module->to(torch::kCUDA);

See if that helps