Expected Tensor (not Variable) for argument #2 'mat2'

Hey, I'm starting out with the PyTorch C++ frontend and I am facing the following error:

terminate called after throwing an instance of 'c10::Error'
  what():  Expected Tensor (not Variable) for argument #2 'mat2' (checked_tensor_unwrap at /pytorch/aten/src/ATen/Utils.h:77)
frame #0: std::function<std::string ()>::operator()() const + 0x11 (0x7fbfd7e49bf1 in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libc10.so)
frame #1: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x2a (0x7fbfd7e4952a in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libc10.so)
frame #2: <unknown function> + 0xcbd069 (0x7fbfce585069 in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #3: <unknown function> + 0xccc95d (0x7fbfce59495d in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #4: at::native::mm(at::Tensor const&, at::Tensor const&) + 0x65 (0x7fbfce37cda5 in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #5: at::TypeDefault::mm(at::Tensor const&, at::Tensor const&) const + 0x5d (0x7fbfce72489d in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #6: <unknown function> + 0xabac05 (0x7fbfce382c05 in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #7: at::native::matmul(at::Tensor const&, at::Tensor const&) + 0x27 (0x7fbfce3834e7 in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #8: at::TypeDefault::matmul(at::Tensor const&, at::Tensor const&) const + 0x5d (0x7fbfce7260bd in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #9: at::native::linear(at::Tensor const&, at::Tensor const&, at::Tensor const&) + 0x71 (0x7fbfce371e21 in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #10: at::TypeDefault::linear(at::Tensor const&, at::Tensor const&, at::Tensor const&) const + 0x6c (0x7fbfce726b2c in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libcaffe2.so)
frame #11: torch::nn::LinearImpl::forward(at::Tensor const&) + 0x5d (0x7fbfd8c47c3d in /home/vipul/Documents/vaibhawvipul/DeepLearning/libtorch/lib/libtorch.so.1)
frame #12: Net::forward(at::Tensor) + 0x11a (0x4313de in ./simpleNN)
frame #13: main + 0x496 (0x42d47c in ./simpleNN)
frame #14: __libc_start_main + 0xf0 (0x7fbfccd69830 in /lib/x86_64-linux-gnu/libc.so.6)
frame #15: _start + 0x29 (0x428949 in ./simpleNN)

I am training a simple neural network on the Iris dataset, just to learn the C++ frontend. Here is my code for the neural network:

struct Net : torch::nn::Module {
  Net() {
    // Construct and register two Linear submodules.
    fc1 = register_module("fc1", torch::nn::Linear(4, 4));
    fc2 = register_module("fc2", torch::nn::Linear(4, 3));
    fc3 = register_module("fc3", torch::nn::Linear(3, 3));
  }

  // Implement the Net's algorithm.
  torch::Tensor forward(at::Tensor x) {
    // Use one of many tensor manipulation functions.
    x = fc1->forward(x.reshape(x.size(0)));
    std::cout << x << std::endl;
    x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
    std::cout << x << std::endl;
    x = torch::relu(fc2->forward(x));
    x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
    return x;
  }

  // Use one of many "standard library" modules.
  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

and the following is the main() function:

int main() {

  //loading data
  DataSet data_set("/DeepLearning/simpleNNPytorchCpp/iris.csv");

  //loading model 
  auto net = std::make_shared<Net>();

  //setting up optimizer
  torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

  for (size_t epoch = 1; epoch <= 1; ++epoch) {
    size_t batch_index = 0;
    // Iterate the data loader to yield batches from the dataset.
    for (unsigned i = 0; i < data_set.x1().size(); i++) {

      //form an input tensor combined of x1, x2, x3 and x4
      std::vector<float> input_vector;
      input_vector.push_back(data_set.x1()[i]);
      input_vector.push_back(data_set.x2()[i]);
      input_vector.push_back(data_set.x3()[i]);
      input_vector.push_back(data_set.x4()[i]);

      at::Tensor input_tensor = at::from_blob(input_vector.data(),{long(input_vector.size()),1});
      input_tensor = input_tensor.toType(at::kFloat);

      std::vector<float> target_output;
      target_output = data_set.y()[i];

      at::Tensor output_tensor = at::from_blob(target_output.data(),{long(target_output.size()),1});
      output_tensor = output_tensor.toType(at::kFloat);

      // Reset gradients.
      optimizer.zero_grad();
      // Execute the model on the input data.
      torch::Tensor prediction = net->forward(input_tensor);
      // Compute a loss value to judge the prediction of our model.
      torch::Tensor loss = torch::nll_loss(prediction,output_tensor);
      // Compute gradients of the loss w.r.t. the parameters of our model.
      loss.backward();
      // Update the parameters based on the calculated gradients.
      optimizer.step();
      // Output the loss and checkpoint every 100 batches.
      if (++batch_index % 100 == 0) {
        std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                  << " | Loss: " << loss.item<float>() << std::endl;
        // Serialize your model periodically as a checkpoint.
        torch::save(net, "net.pt");
      }
    }
  }

  return 0;
}

Can someone please help me out?

@ptrblck, can you please help me out?

I guess the error might be thrown because you are using at::Tensor for your input, as well as in forward, instead of torch::Tensor.
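In other words, something along these lines (a rough sketch only; the make_input helper, the trimmed-down Net, and the {1, 4} single-sample shape are just for illustration, not taken from your code):

#include <torch/torch.h>
#include <vector>

// Build the input with torch::from_blob so it is already a torch::Tensor
// (Variable-backed), instead of at::from_blob / at::Tensor.
torch::Tensor make_input(std::vector<float>& values) {
  // from_blob does not take ownership of the buffer, so clone() the result
  // before the vector goes out of scope.
  return torch::from_blob(values.data(),
                          {1, static_cast<long>(values.size())},
                          torch::kFloat)
      .clone();
}

struct Net : torch::nn::Module {
  Net() {
    fc1 = register_module("fc1", torch::nn::Linear(4, 4));
  }

  // Accept and return torch::Tensor instead of at::Tensor.
  torch::Tensor forward(torch::Tensor x) {
    return torch::relu(fc1->forward(x));
  }

  torch::nn::Linear fc1{nullptr};
};

The same applies to the target tensor you pass to the loss: create it via torch::from_blob as well, so everything flowing through the model and the loss is a torch::Tensor.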

PS: Tagging certain people might discourage others from answering in your thread, and I’m pretty sure there are more experienced C++ frontend users here. :wink:


Thanks @ptrblck for the help.

I will keep that in mind; I won’t tag anyone unless absolutely necessary! Sorry about that. :slight_smile: