In the Python API, NLLLoss accepts a target of shape (N, d1, …, dk). In the C++ API, however, torch::nll_loss crashes with the exception multi-target not supported at C:\w\1\s\windows\pytorch\aten\src\THNN/generic/ClassNLLCriterion.c:22
NLL Example
Log softmax succeeded
Ground truth created
Exception occurred
multi-target not supported at C:\w\1\s\windows\pytorch\aten\src\THNN/generic/ClassNLLCriterion.c:22
This error is usually thrown if your target still contains the class dimension:
criterion = nn.NLLLoss()
N, nb_classes = 2, 3
output = torch.randn(N, nb_classes, requires_grad=True)
target = torch.randint(0, nb_classes, (N,))
loss = criterion(output, target) # works
target = torch.randint(0, nb_classes, (N, nb_classes))
loss = criterion(output, target) # fails with same error
For a multi-label classification (a single sample can contain more than a single valid class), you should use nn.BCEWithLogitsLoss (or the C++ equivalent).
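For illustration, here is a minimal multi-label sketch (an assumption about your setup, not code from this thread): with nn.BCEWithLogitsLoss the target is a multi-hot float tensor of the same shape as the logits, so each sample can carry several active classes at once.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
N, nb_classes = 2, 3

# raw logits, one score per class
logits = torch.randn(N, nb_classes, requires_grad=True)

# multi-hot target: each sample may belong to several classes at once
target = torch.randint(0, 2, (N, nb_classes)).float()

loss = criterion(logits, target)  # scalar loss
```

Note that the target must be float here, unlike the long class-index target used by nn.NLLLoss.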
Thanks @ptrblck. It looks like you are using the form of the target with dimension (N,). However, I need the form where the target has dimension (N, d1, …, dk), more explicitly (N, 300, 300). I am trying to train on a semantic segmentation task where each pixel in the image has its own label.
I assume the semantic segmentation is a multi-class segmentation, i.e. each pixel belongs to exactly one class.
If so, this should be possible using these shapes:
N, nb_classes, H, W = 2, 3, 4, 4
output = torch.randn(N, nb_classes, H, W, requires_grad=True)
target = torch.randint(0, nb_classes, (N, H, W))
loss = criterion(output, target)
Note that the nb_classes dimension is missing in the target as described in the docs, so make sure to pass these shapes to your criterion.
nn.BCEWithLogitsLoss is most likely not what you are looking for.
I understand what you are saying, but I can’t see where my code sample is broken. My input is a tensor of dimension (1, 2, 300, 300) (1 batch, 2 classes, 300x300 image). My target (ground_truth) is a tensor (1, 300, 300) filled with 1.
This does indeed seem to be the right shape, and I remember I’ve seen this issue before!
Could you check whether nll_loss2d is defined and, if so, use it instead of nll_loss? (I’m currently not at my machine to check it.)
Unfortunately there is no documentation besides the function declaration.
Hopefully it will work as a drop-in replacement in my code.
Edit: It worked! Well, at least it succeeded without crashing. I’m now manually checking the output to see if it does what I think it should do (pixel-wise NLL loss, then taking the mean).
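That check can be done on the Python side, where nll_loss already supports spatial targets. A small sketch (not from this thread) comparing F.nll_loss against a manual pixel-wise NLL followed by a mean:

```python
import torch
import torch.nn.functional as F

N, C, H, W = 1, 2, 4, 4
log_probs = F.log_softmax(torch.randn(N, C, H, W), dim=1)
target = torch.randint(0, C, (N, H, W))

# built-in loss, default reduction='mean'
loss = F.nll_loss(log_probs, target)

# manual pixel-wise NLL: negative log-probability of the true class
# at every pixel, then averaged over all pixels
manual = -log_probs.gather(1, target.unsqueeze(1)).mean()

assert torch.allclose(loss, manual)
```

If the C++ nll_loss2d matches this, it is doing exactly the pixel-wise NLL plus mean described above.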
Awesome!
Yeah, we are working on the parity of the C++ API, so these things might be missing currently.
Fortunately, we have the discussion board, so feel free to ask if you get stuck somewhere.
I have the same semantic segmentation problem that @markl has, but my images are three-dimensional. Therefore, I would need a loss function that takes target tensors of shape (N, d1, d2, d3). There doesn’t seem to be a torch::nll_loss3d() in C++ Pytorch, as far as I can tell. What is your recommendation in this case? Do I just have to wait until libTorch catches up with python? If so, any idea when that will happen?
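Not an official answer, but one possible workaround while a 3d variant is missing: flatten two of the three spatial dimensions so the 2d variant can consume the tensors; under the default mean reduction the result is identical. A Python sketch of the idea (the same reshape should carry over to torch::nll_loss2d in C++, though I have not verified that):

```python
import torch
import torch.nn.functional as F

N, C, D, H, W = 2, 3, 4, 4, 4
log_probs = F.log_softmax(torch.randn(N, C, D, H, W), dim=1)
target = torch.randint(0, C, (N, D, H, W))

# direct K-dimensional loss (supported in the Python API)
loss_3d = F.nll_loss(log_probs, target)

# flatten the last two spatial dims so a 2d NLL can handle the tensors
loss_2d = F.nll_loss(log_probs.reshape(N, C, D, H * W),
                     target.reshape(N, D, H * W))

assert torch.allclose(loss_3d, loss_2d)
```

The reshape keeps every (pixel, label) pair intact, so the per-element losses and their mean are unchanged.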
Thanks for your answer! Is that supported after some given version? Because I’m using 1.2.0 and NLLLoss only accepts rank-2 tensors, as described by @markl up above in this thread. If you try a higher rank tensor, you get the multi-target not supported error.
#include <torch/torch.h>
#include <iostream>

int main() {
  // NLLLoss expects log-probabilities, so apply log_softmax over the class dim
  auto input = torch::log_softmax(torch::randn({3, 10, 24, 24}), /*dim=*/1);
  auto target = torch::randint(0, 10, {3, 24, 24}, torch::kLong);
  torch::nn::NLLLoss criterion;
  auto loss = criterion->forward(input, target);
  std::cout << loss << std::endl;
}
Also, note that there are different implementations, such as LossNLL2d.cpp, which might accept multiple dimensions (I haven’t checked which function is called in my example, though).
Hi, thanks for the response. Yes, your code works for me too. What doesn’t work is to use the free-function form of the NLL loss:
auto input = torch::randn({3, 10, 24, 24});
auto target = torch::randint(0, 10, {3, 24, 24}, torch::kLong);
auto loss = torch::nll_loss(input, target);
Is there any design reason for this to behave differently from NLLLoss? If not, I suggest either bringing it up to par with the class or removing it altogether to avoid confusion. Thanks again!