Nan gradients when using F.nll_loss

Hello everyone!

I am trying to train an RBM in a discriminative way. The forward pass of the net computes the log-conditional probabilities. The normalization I need to perform in order to get the probabilities, however, does not involve a softmax (hence, I cannot use F.log_softmax); see the DRBM paper, p(y|x), at page 2.

In the problem I’m trying to solve, it is possible to have 0 probabilities: with x the binary input, x[i] = 0 implies p(y=i|x) = 0.
This is usually not a problem when using F.nll_loss(log_probs, target).backward(), since the 0-probability class is never a target. Nevertheless, it does become a problem when the log-probabilities are computed with torch.log(logits/logits.sum(-1)).

I report here an example to reproduce the error:

import torch
import torch.nn.functional as F

w = torch.tensor([1.,2.,3.], requires_grad=True)
x = torch.tensor([[1,0,1]])
c = torch.tensor([0])

logits = w*x

If I had to normalize using a softmax, then I could use:

l_mod = logits.clone()
l_mod[l_mod==0] = float("-inf")  # mask the impossible (zero-probability) classes
print(F.log_softmax(l_mod, dim=-1))
# out: tensor([[-2.1269,    -inf, -0.1269]], grad_fn=<LogSoftmaxBackward>)
F.nll_loss(F.log_softmax(l_mod, dim=-1), c).backward()  # or F.cross_entropy(...)
print(w.grad)
# out: tensor([-0.8808,  0.0000,  0.8808])

If, instead, I normalize without using the softmax, I get:

print(torch.log(logits/logits.sum(-1)))
#out: tensor([[-1.3863,    -inf, -0.2877]], grad_fn=<LogBackward>)
F.nll_loss(-torch.log(logits/logits.sum(-1)), c).backward()
print(w.grad)
#out: tensor([nan, nan, nan])

For the moment I have worked around the problem by either

  1. using (logits/logits.sum(-1)).clamp(min=1e-16) before taking the log (a sketch follows this list), or
  2. taking the log only of the positive probabilities:
    probs = logits/logits.sum(-1)
    return torch.log(probs[probs>0])
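
A minimal sketch of workaround 1 on the toy tensors above (the 1e-16 floor is an arbitrary choice):

import torch
import torch.nn.functional as F

w = torch.tensor([1.,2.,3.], requires_grad=True)
x = torch.tensor([[1,0,1]])
c = torch.tensor([0])

logits = w*x
probs = logits/logits.sum(-1)

# Clamping keeps the zero probability away from log(0): the local derivative of
# log at 1e-16 is finite, so the 0 upstream gradient for that entry stays 0
# instead of becoming 0*inf = nan.
log_probs = torch.log(probs.clamp(min=1e-16))

F.nll_loss(log_probs, c).backward()
print(w.grad)
# out: tensor([-0.7500,  0.0000,  0.2500]) -- no nan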

However, my question remains: why do I get nan without such tricks? Are there better ways to avoid this problem?

Thank you in advance!

Edit: The explanation I initially had was bogus.
It’s 0 * inf in a way (see below), and using log_softmax has the “correct limit” that avoids the inf * 0 in the backward.

Hi Tom! First of all, thank you for your answer. Could you please point out to me where, in particular, I end up doing -inf/-inf? When differentiating the loss of my example (-1.3863), I should not encounter any -inf in the backward steps, no? Last but not least: I have -inf log-probabilities in the log_softmax case as well. Why are the gradients well defined in that case?

Sorry, my bad…
So you get a 0 logit and thus an inf score, and feed it into the NLL. This gives you a 0 grad in the backward for this entry, which then gets multiplied by 1/0 = inf (the derivative of log at 0), and so 0 * inf = NaN. This is the curse of backpropagation here, because in the next step it is multiplied by 0 again, and NaN * 0 is still NaN.
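
To make that 0 * inf visible, here is a minimal sketch with the probabilities from your example as a leaf tensor:

import torch

p = torch.tensor([0.25, 0.0, 0.75], requires_grad=True)
nll = -torch.log(p)[0]  # only the target entry ends up in the loss
nll.backward()
print(p.grad)
# out: tensor([-4., nan, 0.])
# The zero entry gets an upstream gradient of 0 from the indexing, but the local
# derivative of log at 0 is 1/0 = inf, and 0 * inf = nan.
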
With log_softmax you do not run into the problem, because it can operate entirely in log-space (you have log_softmax(x) = x - logsumexp(x), so you avoid the derivative of log) and so its backward can tell that the gradient for that entry is exactly 0.
But is what you pass to nll_loss actually a vector of log-probabilities? It would seem the logsumexp is inf due to the sign flip (the -torch.log(...)). It would need to be log-probabilities for the loss to be mathematically meaningful.

Best regards

Thomas

Thank you a lot, Thomas! This helps me understand why the code works with F.log_softmax() but not with torch.log(...). Still, I did not understand why the backward() pass had to go through the problematic derivative in the first place. In particular, I can avoid the error by doing the following:

w = torch.tensor([1.,2.,3.], requires_grad=True)
x = torch.tensor([[1,0,1]])
c = torch.tensor([0])

p = w*x / (w*x).sum(-1)
print(p)

nll = -torch.log(p[0][0]) #target class = c = 0
print(nll)

nll.backward()

print(w.grad)

which produces the following output:

tensor([[0.2500, 0.0000, 0.7500]], grad_fn=<DivBackward0>) # p
tensor(1.3863, grad_fn=<NegBackward>) #nll_loss
tensor([-0.7500,  0.0000,  0.2500]) #  gradients

As expected, the gradient of this nll loss (equivalent to the example above) is fine, since no -inf score is involved in its computation. If, instead, I use the following to produce the log-probabilities (and the nll loss), in a fresh run with w, x, c and p re-created as above, I get the nan gradients:

logits = torch.log(p)
nll = -logits[0][0]
print(nll)

nll.backward()

print(w.grad)

output:

tensor([[0.2500, 0.0000, 0.7500]], grad_fn=<DivBackward0>) #p (same as before!)
tensor(1.3863, grad_fn=<NegBackward>) #nll_loss (same as before!)
tensor([nan, nan, nan]) # gradients

The only difference from the previous example is that in this case I build the computational graph for the useless (and problematic) entry, logits[0][1] = torch.log(p[0][1]) = torch.log(0) = -inf. Even though I don’t care about this value in the nll loss computation, its backward graph is still processed when I do F.nll_loss(torch.log(p), c).backward().
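
For completeness, a sketch of a variant of workaround 2 that keeps the full tensor shape (so F.nll_loss can still be used): mask the zero probabilities with torch.where before taking the log, so 1/0 never enters the backward for those entries.

import torch
import torch.nn.functional as F

w = torch.tensor([1.,2.,3.], requires_grad=True)
x = torch.tensor([[1,0,1]])
c = torch.tensor([0])

logits = w*x
p = logits/logits.sum(-1)

# Replace the zero probabilities with a constant *before* the log. For the masked
# positions torch.where routes the gradient to the constant branch, so p receives
# exactly 0 there and no inf appears in the backward. The masked entry's "log
# probability" of 0 is meaningless, but that class is never a target anyway.
safe_p = torch.where(p > 0, p, torch.ones_like(p))

F.nll_loss(torch.log(safe_p), c).backward()
print(w.grad)
# out: tensor([-0.7500,  0.0000,  0.2500]) -- no nan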

Thank you again.