Correct usage of torch::nll_loss

This toy program crashes when trying to call torch::nll_loss. It’s supposed to simulate a pixel-wise classification loss.

I was hoping that the shape rules in the Python documentation https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss (the "Shape" subsection) would also apply to the C++ function.

Here is the toy program

#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>

int main(int argc, char** argv) {
  std::cout << "NLL Example" << std::endl;

  constexpr int kBatchSize = 1;
  at::Tensor net_output =
      torch::log_softmax(torch::randn({kBatchSize, 2, 300, 300}), /*dim=*/1);
  std::cout << "Log softmax succeeded" << std::endl;

  at::Tensor ground_truth = torch::ones({kBatchSize, 300, 300});
  std::cout << "Ground truth created" << std::endl;

  at::Tensor loss = torch::nll_loss(net_output, ground_truth);
  std::cout << "Done" << std::endl;
} 

The program crashes silently, with only this output:

NLL Example
Log softmax succeeded
Ground truth created

Can you attach a debugger and share the stack trace? Here is an example of its usage from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/integration.cpp:

  for (size_t epoch = 0; epoch < number_of_epochs; epoch++) {
    for (torch::data::Example<> batch : *data_loader) {
      auto data = batch.data.to(device), targets = batch.target.to(device);
      torch::Tensor prediction = forward_op(std::move(data));
      torch::Tensor loss = torch::nll_loss(prediction, std::move(targets));
      AT_ASSERT(!torch::isnan(loss).any().item<int64_t>());
      optimizer.zero_grad();
      loss.backward();
      optimizer.step();
    }
  }

Here is a capture from WinDbg

I also caught the exception and printed e.what():

NLL Example
Log softmax succeeded
Ground truth created
Exception occured
Expected object of scalar type Long but got scalar type Float for argument #2 'target' (checked_tensor_unwrap at C:\w\1\s\windows\pytorch\aten\src\ATen/Utils.h:80)
(no backtrace available)

Based on the exception, I have updated the code to call .to(at::kLong) on the ground truth tensor. I now get a different error (below). I think this means I need to look for a multi-target version of the nll function.

#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>

int main(int argc, char** argv) {
  std::cout << "NLL Example" << std::endl;

  constexpr int kBatchSize = 1;
  at::Tensor net_output =
      torch::log_softmax(torch::randn({kBatchSize, 2, 300, 300}), /*dim=*/1);
  std::cout << "Log softmax succeeded" << std::endl;

  at::Tensor ground_truth = torch::ones({kBatchSize, 300, 300}).to(at::kLong);
  std::cout << "Ground truth created" << std::endl;

  try {
    at::Tensor loss = torch::nll_loss(net_output, ground_truth);
  } catch (const std::exception& e) {
    std::cout << "Exception occured\n" << e.what() << std::endl;
    exit(-1);
  }
  std::cout << "Done" << std::endl;
}

The output is now

NLL Example
Log softmax succeeded
Ground truth created
Exception occured
multi-target not supported at C:\w\1\s\windows\pytorch\aten\src\THNN/generic/ClassNLLCriterion.c:22