What's the C++ API equivalent of torch.nn.CrossEntropyLoss()?

(Yao Zihang) #1

I want to use torch.nn.CrossEntropyLoss to compute a loss in C++, but I couldn't find a corresponding function in the C++ API; the only cross-entropy function I found was binary_cross_entropy. Is there a function equivalent to torch.nn.CrossEntropyLoss?

(kerry.cho) #2

I couldn't find it either. Try applying log_softmax over the class dimension and passing the result to nll_loss:
torch::nll_loss(out.log_softmax(1), targets);
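This works because torch.nn.CrossEntropyLoss is log_softmax followed by the negative log-likelihood loss. The standalone sketch below uses plain std::vector instead of libtorch tensors to check that equivalence numerically. log_softmax_rows, nll_loss_mean and cross_entropy_direct are hypothetical helpers written only for illustration, and the sketch covers just the default mean reduction with no class weights. Newer libtorch releases also expose torch::nn::CrossEntropyLoss and torch::nn::functional::cross_entropy directly.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helpers for illustration only; these are NOT libtorch functions.
// A "Matrix" holds raw scores (logits) laid out as [batch][classes].
using Matrix = std::vector<std::vector<double>>;

// Row-wise log-softmax, shifted by the row max for numerical stability.
// Plays the role of out.log_softmax(1) on a 2-D [N, C] tensor.
Matrix log_softmax_rows(const Matrix& logits) {
    Matrix out;
    out.reserve(logits.size());
    for (const auto& row : logits) {
        double max_val = row[0];
        for (double v : row) if (v > max_val) max_val = v;
        double sum = 0.0;
        for (double v : row) sum += std::exp(v - max_val);
        const double log_sum = max_val + std::log(sum);  // log(sum_j exp(row[j]))
        std::vector<double> lsm;
        lsm.reserve(row.size());
        for (double v : row) lsm.push_back(v - log_sum);
        out.push_back(lsm);
    }
    return out;
}

// Mean negative log-likelihood: -mean_i log_probs[i][targets[i]].
// Corresponds to nll_loss with mean reduction and no class weights.
double nll_loss_mean(const Matrix& log_probs, const std::vector<std::size_t>& targets) {
    assert(!log_probs.empty() && log_probs.size() == targets.size());
    double total = 0.0;
    for (std::size_t i = 0; i < log_probs.size(); ++i) total -= log_probs[i][targets[i]];
    return total / static_cast<double>(log_probs.size());
}

// Cross-entropy computed directly from softmax probabilities:
// -mean_i log(exp(x[i][t_i]) / sum_j exp(x[i][j])). No max-shift, so only for small logits.
double cross_entropy_direct(const Matrix& logits, const std::vector<std::size_t>& targets) {
    assert(!logits.empty() && logits.size() == targets.size());
    double total = 0.0;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        double denom = 0.0;
        for (double v : logits[i]) denom += std::exp(v);
        total -= std::log(std::exp(logits[i][targets[i]]) / denom);
    }
    return total / static_cast<double>(logits.size());
}
```

For example, with logits {{2.0, 1.0, 0.1}, {0.5, 2.5, -1.0}} and targets {0, 1}, nll_loss_mean(log_softmax_rows(logits), targets) and cross_entropy_direct(logits, targets) agree up to floating-point rounding. In the libtorch one-liner, out is expected to be an [N, C] tensor of raw scores and targets a tensor of class indices with dtype torch::kLong.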

(Yao Zihang) #3

Thanks for your advice!
