MNIST with the PyTorch C++ API

Hello everyone,

I’m experimenting with the newly released C++ API of PyTorch. It looks very promising, and I think it’s definitely going in the right direction.

However, as you can imagine, I ran into some problems that I haven’t been able to solve so far. First things first:

  1. I used PyTorch (Python) to train an MNIST model. Nothing special here: two conv2d layers, dropout, and two linear layers.

  2. I used TorchScript (via tracing) to save my model so that I can load it later in C++.

    import torch
    from torchvision import datasets, transforms

    use_cuda = torch.cuda.is_available()
    mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)
    train_image, train_target = mnist_testset[24]
    print(type(train_image))

    train_image.show()

    device = torch.device("cuda" if use_cuda else "cpu")
    model = torch.load("model.pth").to(device)
    model.eval()  # put dropout into inference mode before running and tracing
    loader = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    tensor_image = loader(train_image).unsqueeze(0).to(device)  # shape 1x1x28x28
    output = model(tensor_image)
    pred = output.max(1, keepdim=True)[1]
    pred = torch.squeeze(pred)

    # int() works whether the target comes back as a plain int or a 0-dim tensor
    print("Success - Train target: " + str(int(train_target)) + " Prediction: " + str(pred.item()))

    # TRACING THE MODEL
    traced_net = torch.jit.trace(model, tensor_image)
    traced_net.save("model_trace.pt")

Here, model is my PyTorch model and tensor_image is an example input, which tracing requires because the tracer records the operations executed on it. The result is a model_trace.pt file that can be loaded from C++.
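The loader pipeline above is what the C++ side will eventually have to reproduce when it feeds real images instead of random ones. As a minimal sketch in plain C++ (no libtorch calls), assuming the image arrives as 784 row-major 8-bit grayscale pixels (for example a continuous single-channel OpenCV buffer), ToTensor followed by Normalize((0.1307,), (0.3081,)) amounts to scaling each pixel to [0, 1] and then computing (x - 0.1307) / 0.3081. The helper name preprocess_mnist is hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: reproduces transforms.ToTensor() followed by
// transforms.Normalize((0.1307,), (0.3081,)) for a single 28x28 grayscale image.
// Input:  784 bytes in row-major order (assumed layout, e.g. a continuous CV_8UC1 buffer).
// Output: 784 floats in the element order of a 1x1x28x28 (N, C, H, W) tensor.
std::vector<float> preprocess_mnist(const std::vector<std::uint8_t>& pixels) {
    constexpr std::size_t kSide = 28;
    constexpr float kMean = 0.1307f;  // mean used in the Python Normalize call
    constexpr float kStd = 0.3081f;   // std used in the Python Normalize call

    if (pixels.size() != kSide * kSide) {
        throw std::invalid_argument("preprocess_mnist: expected 28x28 pixels");
    }

    std::vector<float> out(pixels.size());
    for (std::size_t i = 0; i < pixels.size(); ++i) {
        const float x = static_cast<float>(pixels[i]) / 255.0f;  // ToTensor: [0, 255] -> [0, 1]
        out[i] = (x - kMean) / kStd;                              // Normalize: (x - mean) / std
    }
    return out;
}
```

Because there is a single channel and a batch of one, the flat buffer already has the element order of the 1x1x28x28 example input used for tracing; wrapping it into a tensor is left to the libtorch side and is not shown here.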

Alright, so far so good!

  3. Next, I wanted to run the model in C++ so that I can do the forward pass of a sample MNIST image there.

First and most importantly, include the TorchScript header:

#include <torch/script.h> // One-stop header.

Then I was able to load the model in C++ using:

std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("/home/Dev/testtorch/model_trace.pt");
  4. After the model loaded successfully, I needed some input to try it out, so I created a random tensor as input. Later I will have to read an image with OpenCV, access its data and convert it to a tensor, but for now I just wanted to check whether my model accepts any input at all.
at::Tensor randomInput = at::rand({1, 1, 28, 28});  // N x C x H x W, same shape as the traced example input
  5. I used the newly created randomInput to call the forward function of my model:
module->forward({randomInput});

But here comes the problem: the build succeeds, but I get the following runtime error:

terminate called after throwing an instance of 'at::Error'
  what():  Tensor that was converted to Variable was not actually a Variable (Variable at /pytorch/torch/csrc/autograd/variable.h:120)
frame #0: <unknown function> + 0x483fef (0x7fb619733fef in /home/narvis/Lib/libtorch/lib/libtorch.so.1)
frame #1: <unknown function> + 0x4842a1 (0x7fb6197342a1 in /home/narvis/Lib/libtorch/lib/libtorch.so.1)
frame #2: <unknown function> + 0x4886aa (0x7fb6197386aa in /home/narvis/Lib/libtorch/lib/libtorch.so.1)
frame #3: torch::jit::script::Method::run(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >&) + 0xf6 (0x42fbf8 in /home/narvis/Dev/testtorch/cmake-build-debug/example-app)
frame #4: torch::jit::script::Method::operator()(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >) + 0x4a (0x42fc86 in /home/narvis/Dev/testtorch/cmake-build-debug/example-app)
frame #5: torch::jit::script::Module::forward(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >) + 0x81 (0x430495 in /home/narvis/Dev/testtorch/cmake-build-debug/example-app)
frame #6: main + 0x548 (0x42c1a0 in /home/narvis/Dev/testtorch/cmake-build-debug/example-app)
frame #7: __libc_start_main + 0xf0 (0x7fb60eae6830 in /lib/x86_64-linux-gnu/libc.so.6)
frame #8: _start + 0x29 (0x42b209 in /home/narvis/Dev/testtorch/cmake-build-debug/example-app)

I suspect this error has something to do with the ATen tensor library, but I couldn’t figure out what the problem is. If you have any thoughts on this, please share them!

Hi,

I’m not sure, but I think it expects a torch::Tensor as input, no?
Be careful, as at::Tensor and torch::Tensor are not the same ;)

You are absolutely right!

Christian S. Perone wrote a nice article about it:

A word of caution for those who are starting now is to be careful with the use of the tensors that can be created both from ATen and autograd, do not mix them, the ATen will return the plain tensors (when you create them using the at namespace) while the autograd functions (from the torch namespace) will return Variable, by adding its automatic differentiation mechanism.

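So this:

at::Tensor test = at::rand({1, 1, 28, 28});

gives you an at::Tensor without autograd functionality (a plain ATen tensor), while this:

at::Tensor test = torch::rand({1, 1, 28, 28});

gives you an at::Tensor with autograd functionality (a Variable), which is what the traced module expects as input.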
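Once forward runs with a proper input, the predicted digit is the index of the largest of the 10 output scores, which is what output.max(1, keepdim=True)[1] computes in the Python script above. A minimal plain-C++ sketch, assuming the scores for one image have already been copied out of the output tensor into a std::vector<float> (that copy is not shown); predict_digit is a hypothetical name. If the network happens to end in log_softmax, the argmax is the same as for raw logits, since log_softmax preserves the ordering of the scores.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <stdexcept>
#include <vector>

// Hypothetical helper: index of the largest score for ONE image, i.e. the
// predicted digit, analogous to output.max(1, keepdim=True)[1] in Python.
// std::max_element returns the first maximum, so ties go to the lower index here.
std::size_t predict_digit(const std::vector<float>& scores) {
    if (scores.empty()) {
        throw std::invalid_argument("predict_digit: no scores given");
    }
    const auto best = std::max_element(scores.begin(), scores.end());
    return static_cast<std::size_t>(std::distance(scores.begin(), best));
}
```

Tie-breaking is a property of this sketch only; it is not a claim about how PyTorch's max resolves ties.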

Thanks again!


I wrote a minimal example of how to train a model in PyTorch using Python and then load and run it from C++. I hope this helps!