Basic NLLLoss Calculation Question

Hi, I am confused about how to use torch.nn.NLLLoss. Below is a simple session in the Python REPL. I am expecting the result to be 0.35667494393873245, but I am getting -0.7000. I'd greatly appreciate it if someone could steer me in the right direction on this. Thanks!

>>> import torch
>>> import torch.nn as nn
>>> input = torch.tensor([[0.70, 0.26, 0.04]])
>>> loss = nn.NLLLoss()
>>> target = torch.tensor([0])
>>> output = loss(input, target)
>>> output
tensor(-0.7000)
>>> import math
>>> -math.log(0.70)
0.35667494393873245
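In case it helps anyone else hitting this: with the default reduction, nn.NLLLoss does no log or softmax itself; it just averages the negated input entries at the target indices, which is why you see -0.7000 here. A minimal sketch of that equivalence:

```python
import torch
import torch.nn as nn

input = torch.tensor([[0.70, 0.26, 0.04]])
target = torch.tensor([0])

# NLLLoss (default reduction="mean") computes mean(-input[i, target[i]]);
# it assumes input already contains log-probabilities.
loss = nn.NLLLoss()(input, target)
by_hand = -input[0, target[0]]

print(loss.item(), by_hand.item())  # both -0.7
```

So the -0.7000 is simply the negated raw value at index 0, not a negative log likelihood.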

I think I can see what's happening now. I was a bit confused about how NLLLoss works. The session below shows that taking the negative log of a softmaxed input produces the same result as running the input through log_softmax and multiplying by -1. It also shows that applying CrossEntropyLoss to raw_input is the same as applying NLLLoss to log_softmax(raw_input). I'm guessing that the log_softmax approach may be more numerically stable than applying softmax first and taking the log of the result separately.

>>> raw_input = torch.tensor([[0.7, 0.26, 0.04]])
>>> softmax_input = torch.softmax(raw_input, dim=1)
>>> softmax_input
tensor([[0.4628, 0.2980, 0.2392]])
>>> -torch.log(softmax_input)
tensor([[0.7705, 1.2105, 1.4305]])
>>> log_softmax_input = torch.log_softmax(raw_input, dim=1)
>>> log_softmax_input * -1
tensor([[0.7705, 1.2105, 1.4305]])
>>> loss = nn.NLLLoss()
>>> loss(log_softmax_input, torch.tensor([0]))
tensor(0.7705)
>>> cross_entropy_loss = nn.CrossEntropyLoss()
>>> cross_entropy_loss(raw_input, torch.tensor([0]))
tensor(0.7705)
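A quick check of the numerical-stability guess (a sketch with made-up extreme logits): for large logits, softmax underflows to exactly 0 for the small entries, so taking the log afterwards produces -inf, while log_softmax stays finite.

```python
import torch

# Extreme logits: softmax returns [1, 0, 0] because the small entries
# underflow, so log(softmax(x)) yields -inf for them.
logits = torch.tensor([[1000.0, 0.0, -1000.0]])

naive = torch.log(torch.softmax(logits, dim=1))
stable = torch.log_softmax(logits, dim=1)

print(naive)   # tensor([[0., -inf, -inf]])
print(stable)  # tensor([[0., -1000., -2000.]])
```

This is exactly why CrossEntropyLoss fuses log_softmax and NLLLoss internally instead of letting you chain softmax and log yourself.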

I understand from your experiment that nn.NLLLoss does not expect the likelihood as input, but the log likelihood (the output of log_softmax). Do you agree with this assessment?

If that is so, I guess this is a source of lots of errors, because I believe most people would wrongly guess that you should pass the likelihood (plain probabilities) to nn.NLLLoss. Right?
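To make that pitfall concrete, here is a small sketch of the wrong call next to the right one: passing probabilities makes NLLLoss happily negate them and even return a negative "loss", whereas passing log-probabilities gives the expected cross-entropy value.

```python
import torch
import torch.nn as nn

loss = nn.NLLLoss()
target = torch.tensor([0])

probs = torch.softmax(torch.tensor([[0.7, 0.26, 0.04]]), dim=1)

# Wrong: probabilities in, so the "loss" is just -probs[0, 0] (negative!)
wrong = loss(probs, target)

# Right: log-probabilities in, matching CrossEntropyLoss on the raw logits
right = loss(torch.log(probs), target)

print(wrong.item(), right.item())
```

The negative result is a handy red flag: a proper NLL over log-probabilities can never be negative, since log-probabilities are at most 0.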


Yes, indeed. The documentation does state this, though, so either I didn't read it properly at the time or I misinterpreted it:

The input given through a forward call is expected to contain log-probabilities of each class.