AssertionError running a pytorch model in C++ with libtorch

I’m following the tutorial here, and after getting it to load the model I wanted to run it. I had to make some small adjustments to the example code because my input features are indices. (As an aside, I could not find any documentation for torch::ones; I expected it to be here but didn’t find it. Also, trying to use torch::kInt64 resulted in “is not a member of ‘torch’” despite it being listed on that page?)

I got it to compile with this:

#include <torch/script.h> // One-stop header.

#include <cassert>  // for assert()
#include <iostream>
#include <memory>
#include <vector>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1; 
  }

  // Deserialize the ScriptModule from a file using torch::jit::load().
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

  assert(module != nullptr);

  std::vector<torch::jit::IValue> inputs;
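  // nn.Embedding expects int64 (Long) indices; build the options in one
  // expression so the dtype is kept in the object we pass on.
  at::TensorOptions tens_opts = at::TensorOptions().dtype(at::kLong);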
  inputs.push_back(torch::ones({1, 2}, tens_opts));

  auto output = module->forward(inputs).toTensor();

  std::cout << output.slice(0) << std::endl;
}

But when running it I get an error:

terminate called after throwing an instance of 'at::Error'
  what():  isTensor() ASSERT FAILED at [...]/libtorch/include/ATen/core/ivalue.h:153, please report a bug to PyTorch. (toTensor at [...]/libtorch/include/ATen/core/ivalue.h:153)

I’m doing slicing in the forward definition of my model; could that be the issue? Here’s my model definition:

import torch as t
import torch.nn as nn
import torch.nn.functional as F

class CRNN(nn.Module):
    def __init__(self, num_inp, num_hid, num_ff, num_layers, num_out):
        super(CRNN, self).__init__()

        self.embed = nn.Embedding(num_inp, num_hid)
        self.rnn = nn.GRU(num_hid, num_hid, num_layers, batch_first=True)
        self.fc_emb_skip = nn.Linear(num_hid, num_ff)
        self.fc1 = nn.Linear(num_hid + num_ff, num_ff)
        self.fc2 = nn.Linear(num_ff, 3)
        self.hidden_init = nn.Parameter(t.randn(num_layers, 1, num_hid).type(t.FloatTensor), requires_grad=True)
        self.num_inp = num_inp
        self.num_hid = num_hid
        self.num_out = num_out
        self.num_layers = num_layers

    def forward(self, x):
        hidden = self.init_hidden(x.size(0))
        emb = self.embed(x)
        output, hidden = self.rnn(emb, hidden)
        y_skip = F.elu(self.fc_emb_skip(emb[:, :-1]))
        joined_out = t.cat((y_skip, output[:, 1:],), dim=2)
        outview = joined_out.contiguous().view(joined_out.size(0) * joined_out.size(1), joined_out.size(2))
        y = F.elu(self.fc1(outview))
        logprobs = F.log_softmax(self.fc2(y), 1)
        return logprobs, hidden

    def init_hidden(self, batch_size):
        return self.hidden_init.repeat(1, batch_size, 1)

Is module->forward(inputs) returning a tensor or a tuple of tensors? Your forward() returns (logprobs, hidden), so it is probably a tuple; in that case you need to unpack the tuple and call toTensor() on the element you want.


Oh yeah, you’re right, it’s a tuple. That explains why it fails now…

How would you unpack the tuple? It’s not a std::tuple, is it?

Something like replacing .toTensor() with .toTuple()->elements()[0].toTensor() will attempt to turn the first element in the tuple into a tensor object.

Not sure if that’s the recommended method.


That worked! I’ll go with that unless someone else suggests otherwise.

You could also use std::get to get the elements of a tuple.
Not sure if that’s better than your current approach.
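One caveat: std::get works on std::tuple (and std::pair/std::array), while toTuple()->elements() hands back a container of IValues rather than a std::tuple. Using std::get would therefore mean copying the elements into a std::tuple first. A minimal sketch of that in plain C++14, with a hypothetical helper first_n_as_tuple that works on any container with size() and operator[]:

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>

// Copy c[0], c[1], ..., c[N-1] into a std::tuple (one index per element).
template <typename Container, std::size_t... I>
auto first_n_as_tuple_impl(const Container& c, std::index_sequence<I...>) {
  return std::make_tuple(c[I]...);
}

// Hypothetical helper: turn the first N elements of an indexable container
// into a std::tuple so that std::get<i> can be used on the result.
template <std::size_t N, typename Container>
auto first_n_as_tuple(const Container& c) {
  if (c.size() < N) {
    throw std::out_of_range("container has fewer than N elements");
  }
  return first_n_as_tuple_impl(c, std::make_index_sequence<N>{});
}
```

With libtorch, something like `auto outs = first_n_as_tuple<2>(result.toTuple()->elements());` followed by `std::get<1>(outs).toTensor()` should give the hidden state. Whether that reads better than indexing elements() directly is mostly a matter of taste.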


Interesting to know about std::get, thanks.

How do I get the second output? I’m trying .toTuple()->elements()[1].toTensor() with libtorch 1.1.0. Thanks.