C++ API - Sequential::forward(..)

Hi there, I’m trying to write a simple character-based RNN model and then train it. I finished the C++ frontend tutorial, but I’m getting a strange runtime error from within the depths of the Sequential::forward() function and I’m not sure how to debug it. Build specifics: I built PyTorch 1.11.0 from source, and I’m running on a Mac M1 (so it’s an ARM64 build of PyTorch).

Specifically, the error is thrown from the last statement in the following (Sequential::forward() in sequential.h):

  template <typename ReturnType = Tensor, typename... InputTypes>
  ReturnType forward(InputTypes&&... inputs) {
    TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");

    auto iterator = modules_.begin();
    auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);

    for (++iterator; iterator != modules_.end(); ++iterator) {
      input = iterator->any_forward(std::move(input));
    }
    // ... (rest of forward() elided)
I’m getting the error:

c10::Error: Expected argument #0 to be of type at::Tensor, but received value of type std::__1::tuple<at::Tensor, at::Tensor>

with the following call stack:

    frame #8: 0x000000018558cbc8 libc++abi.dylib`__cxa_throw + 140
    frame #9: 0x00000001002440bc libc10.dylib`c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 120
    frame #10: 0x00000001002a4040 libcharrnn.dylib`std::__1::decay<at::Tensor const&>::type&& torch::nn::AnyModuleHolder<torch::nn::LinearImpl, at::Tensor const&>::CheckedGetter::operator(this=0x000000016fdfe820, index=<unavailable>)<at::Tensor const&>(unsigned long) at any_module_holder.h:48:7 [opt]
    frame #11: 0x00000001002a39c8 libcharrnn.dylib`torch::nn::AnyModuleHolder<torch::nn::LinearImpl, at::Tensor const&>::forward(std::__1::vector<torch::nn::AnyValue, std::__1::allocator<torch::nn::AnyValue> >&&) [inlined] torch::nn::AnyValue torch::unpack<torch::nn::AnyValue, at::Tensor const&, torch::nn::AnyModuleHolder<torch::nn::LinearImpl, at::Tensor const&>::InvokeForward, torch::nn::AnyModuleHolder<torch::nn::LinearImpl, at::Tensor const&>::CheckedGetter, 0ul>(function=InvokeForward @ 0x000000016fdfe840, accessor=CheckedGetter @ 0x000000016fdfe820) at variadic.h:140:48 [opt]
    frame #12: 0x00000001002a39bc libcharrnn.dylib`torch::nn::AnyModuleHolder<torch::nn::LinearImpl, at::Tensor const&>::forward(std::__1::vector<torch::nn::AnyValue, std::__1::allocator<torch::nn::AnyValue> >&&) [inlined] torch::nn::AnyValue torch::unpack<torch::nn::AnyValue, at::Tensor const&, torch::nn::AnyModuleHolder<torch::nn::LinearImpl, at::Tensor const&>::InvokeForward, torch::nn::AnyModuleHolder<torch::nn::LinearImpl, at::Tensor const&>::CheckedGetter>(function=InvokeForward @ x22, accessor=CheckedGetter @ x19) at variadic.h:132:21 [opt]
    frame #13: 0x00000001002a39bc libcharrnn.dylib`torch::nn::AnyModuleHolder<torch::nn::LinearImpl, at::Tensor const&>::forward(this=<unavailable>, arguments=size=1) at any_module_holder.h:106:12 [opt]
    frame #14: 0x00000001002a58d4 libcharrnn.dylib`torch::nn::AnyValue torch::nn::AnyModule::any_forward<torch::nn::AnyValue>(this=0x000000010207d7e8, arguments=0x000000016fdfe960) at any.h:277:20 [opt]
    frame #15: 0x00000001002a05cc libcharrnn.dylib`at::Tensor torch::nn::SequentialImpl::forward<at::Tensor, at::Tensor&>(this=0x0000000102085c70, inputs=0x000000016fdfe9e0) at sequential.h:181:25 [opt]
    frame #16: 0x000000010029f89c libcharrnn.dylib`RnnTrainer2::train(this=<unavailable>) at RnnTrainer2.cpp:399:26 [opt]
    frame #17: 0x0000000100002880 rnn-text`main(argc=<unavailable>, argv=<unavailable>) at main.cpp:25:12 [opt]
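Based on the error text, the value reaching the Linear module is a std::tuple<at::Tensor, at::Tensor> rather than a plain tensor. My working theory (an assumption on my part, not something I’ve confirmed in the sources) is that torch::nn::RNN’s forward() returns both the output sequence and the final hidden state, so Sequential ends up handing that tuple to Linear. This standalone snippet illustrates what I mean:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto rnn = torch::nn::RNN(torch::nn::RNNOptions(128, 256).num_layers(1));
  auto input = torch::zeros({13, 1, 128});

  // RNN's forward() gives back the full output sequence AND the final
  // hidden state as a tuple, not a single tensor:
  std::tuple<torch::Tensor, torch::Tensor> result = rnn->forward(input);
  std::cout << std::get<0>(result).sizes() << " "    // [13, 1, 256]
            << std::get<1>(result).sizes() << "\n";  // [1, 1, 256]
}

If that’s right, the type check in CheckedGetter is doing exactly its job, and my question is really about how the pieces are supposed to be composed.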

Here is the code I’m using (the input is hardcoded since I just wanted to get something simple working):

void RnnTrainer2::train(){
  // Define the input string and output vector
  std::string input_str = "Hello, world!";
  std::vector<int64_t> output_vec = {8, 5, 12, 12, 15, 2, 23, 15, 18, 12, 4, 0};

  // Define the hyperparameters
  const int64_t input_size = 128;
  const int64_t hidden_size = 256;
  const int64_t output_size = 27;  // 26 characters plus padding

  // Define the RNN model
  torch::nn::RNNOptions rnn_options(input_size, hidden_size);
  rnn_options.num_layers(1);
  auto rnn = torch::nn::RNN(rnn_options);
  auto fc = torch::nn::Linear(hidden_size, output_size);
  auto model = torch::nn::Sequential(rnn, fc);

  // Define the loss function and optimizer
  auto loss_fn = torch::nn::CrossEntropyLoss();
  auto optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(1e-3));

  // Convert the input and output to tensors
  long long inputStrLength = (long long)input_str.length();
  //non-constant-expression cannot be narrowed from type 'std::basic_string<char>::size_type' (aka 'unsigned long') to 'long long' in initializer list [-Wc++11-narrowing]
  //auto input = torch::zeros({input_str.length(), 1, input_size});
  torch::Tensor input = torch::zeros({inputStrLength, 1, input_size});
  for (size_t i = 0; i < input_str.length(); i++) {
    input[i][0][input_str[i]] = 1;
  }
  long long outputVecSize = (long long)output_vec.size();
  //non-constant-expression cannot be narrowed from type 'std::vector<long long>::size_type' (aka 'unsigned long') to 'long long' in initializer
  //auto output = torch::from_blob(output_vec.data(), {output_vec.size()}, torch::kLong);
  auto output = torch::from_blob(output_vec.data(), {outputVecSize}, torch::kLong);

  // Train the model
  const int64_t num_epochs = 1000;
  for (int epoch = 0; epoch < num_epochs; epoch++) {
    auto logits = model->forward(input).squeeze(1);
    auto loss = loss_fn(logits, output);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();
    if (epoch % 100 == 0) {
      std::cout << "Epoch " << epoch << ", Loss: " << loss.item<float>() << std::endl;
    }
  }

}
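
For what it’s worth, if I bypass Sequential and wire the two modules together by hand, unpacking the tuple myself, the forward pass no longer throws (a sketch of the loop body I tried; the shape comments are my own assumptions):

// Replacement for `auto logits = model->forward(input).squeeze(1);`
auto rnn_result = rnn->forward(input);             // std::tuple<output, hidden>
auto rnn_output = std::get<0>(rnn_result);         // [seq_len, 1, hidden_size]
auto logits = fc->forward(rnn_output).squeeze(1);  // [seq_len, output_size]
auto loss = loss_fn(logits, output);

So my real question is whether there’s a supported way to put an RNN inside torch::nn::Sequential, or whether the intended pattern is a small wrapper module that calls the RNN and discards the hidden state.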

Thanks for any suggestions!

Will