# Applying log_softmax manually in a stable way

Hi all,

I have a multiclass classification problem, and my network structure is a bit more complex than usual. In a nutshell, I have 2 sets of labels. The ground truth is always one label from one of the sets (think of it as: labels 0 to C come from one set and labels C+1 to N come from the other).

My network calculates 2 different logit vectors, one for each set, using different architectures. What I do is:

I convert these logits to probability distributions via softmax, so now I have 2 probability distributions, one for each target set: p1 and p2.

I have a learnable scalar s (in range [0, 1]) which weights the learnt probability distributions.
I calculate p1 = s * p1 and p2 = (1 - s) * p2.

Basically, depending on the case, I generate the label from one of the sets. At the end I concatenate both distributions, and the result is still a valid probability distribution since s is in range [0, 1]. Lastly, I want to compute the loss on that.
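A minimal sketch of this setup (the sizes, seed, and value of s are made up for illustration), checking that the concatenation really does sum to 1:

```python
import torch

torch.manual_seed(0)

# made-up sizes: 4 labels in the first set, 3 in the second
logits1 = torch.randn(2, 4)
logits2 = torch.randn(2, 3)
s = torch.tensor(0.3)   # learnable in the real model; fixed here for illustration

p1 = torch.softmax(logits1, dim=1)
p2 = torch.softmax(logits2, dim=1)
combined = torch.cat((s * p1, (1 - s) * p2), dim=1)

# each row still sums to 1, so the concatenation is a valid distribution
print(combined.sum(dim=1))
```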

I cannot use cross-entropy loss because it requires raw logits, so I have to use NLLLoss. However, NLLLoss requires log probabilities. Since I already applied softmax, I have to apply log manually, but this leads to instabilities and numerical issues, as mentioned in the documentation.

But I need to apply the softmax in the first place, because if I apply log_softmax then I do not have valid probability distributions and I cannot weight them in a reasonable way with the parameter s.

Any suggestions?

Won't something like `x.clamp_min(1e-6).log()` be enough?
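For example, a quick sketch (with made-up probability values, including an exact zero):

```python
import torch

p = torch.tensor([0.0, 1e-12, 0.5, 1.0])

naive = torch.log(p)               # -inf at the exact zero
clamped = p.clamp_min(1e-6).log()  # finite everywhere

print(naive)
print(clamped)
```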

Hi Berkay!

Perform your `s`, `(1 - s)` weighting in "log space" so that you work
directly with the log-probabilities and only have to call `log_softmax()`,
with its better numerical stability.

That is, because:

``````log (s * prob) = log (s) + log_prob,
``````

just add `log (s)` (and `log (1 - s)`) to your results of `log_softmax()`,
rather than multiplying the results of `softmax()` by `s` (and `(1 - s)`).

As an aside, after computing the log-probabilities as outlined above, you
could use `CrossEntropyLoss` if you wanted to. The log-probabilities
are legitimate raw-score logits; they've just already been "normalized"
(whatever I mean by that). In my mind it's a little more conceptually
straightforward to feed your log-probabilities into `NLLLoss`, but you'll
get the same answer if you use `CrossEntropyLoss`.
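A quick check of that equivalence (logits and targets are made up): because the log-probabilities already normalize, the `log_softmax()` inside `cross_entropy()` is a no-op on them.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# already-normalized log-probabilities (made-up shapes: batch of 2, 7 classes)
log_probs = F.log_softmax(torch.randn(2, 7), dim=1)
target = torch.tensor([3, 5])

nll = F.nll_loss(log_probs, target)
ce = F.cross_entropy(log_probs, target)  # applies log_softmax again -- a no-op here
print(nll.item(), ce.item())
```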

Here is a script that illustrates the above:

``````import torch
torch.__version__

import math

torch.manual_seed (2020)

# 7 logits
# first group contains 4, second group 3
N = 7
C = 4
nBatch = 2   # batch size of 2, for example

s = 1/3

logits = torch.randn ((nBatch, N))
logits

lp1 = torch.nn.functional.log_softmax (logits[:, :C], dim = 1)  # log-probs for first group
lp2 = torch.nn.functional.log_softmax (logits[:, C:], dim = 1)  # log-probs for second group

log_probs = torch.cat ((math.log (s) + lp1, math.log (1.0 - s) + lp2), dim = 1)
log_probs   # all 7 log-probs

torch.exp (log_probs).sum (dim = 1)   # check that it's a valid probability distribution

# the log_probs are legitimate logits -- verify that log_softmax() doesn't change them
torch.nn.functional.log_softmax (log_probs, dim = 1)
``````

And here is its output:

``````>>> torch.__version__
'1.6.0'
>>>
>>> import math
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x7eff574ed930>
>>>
>>> # 7 logits
>>> # first group contains 4, second group 3
>>> N = 7
>>> C = 4
>>> nBatch = 2   # batch size of 2, for example
>>>
>>> s = 1/3
>>>
>>> logits = torch.randn ((nBatch, N))
>>> logits
tensor([[ 1.2372, -0.9604,  1.5415, -0.4079,  0.8806,  0.0529,  0.0751],
        [ 0.4777, -0.6759, -2.1489, -1.1463, -0.2720,  1.0066, -0.0416]])
>>>
>>> lp1 = torch.nn.functional.log_softmax (logits[:, :C], dim = 1)  # log-probs for first group
>>> lp2 = torch.nn.functional.log_softmax (logits[:, C:], dim = 1)  # log-probs for second group
>>>
>>> log_probs = torch.cat ((math.log (s) + lp1, math.log (1.0 - s) + lp2), dim = 1)
>>> log_probs   # all 7 log-probs
tensor([[-2.0768, -4.2745, -1.7725, -3.7219, -1.0388, -1.8665, -1.8443],
        [-1.5592, -2.7127, -4.1858, -3.1831, -2.1721, -0.8934, -1.9416]])
>>>
>>> torch.exp (log_probs).sum (dim = 1)   # check that it's a valid probability distribution
tensor([1.0000, 1.0000])
>>>
>>> # the log_probs are legitimate logits -- verify that log_softmax() doesn't change them
>>> torch.nn.functional.log_softmax (log_probs, dim = 1)
tensor([[-2.0768, -4.2745, -1.7725, -3.7219, -1.0388, -1.8665, -1.8443],
        [-1.5592, -2.7127, -4.1858, -3.1831, -2.1721, -0.8934, -1.9416]])
>>>
``````

Best.

K. Frank


Dear KFrank,

Thank you very much for your reply. What you said makes perfect sense. I implemented my model as you described and it is fine. However, the loss now becomes nan after a while. I ran torch anomaly detection and the problem is caught there. It tells me:

RuntimeError: Function 'LogBackward' returned nan values in its 0th output.

and the last line called is the following (basically where I do `log(1 - s)`):

``````local_distribution = torch.log(1 - st) + F.log_softmax(attention_weights.squeeze(dim=2), dim=1)
``````

I am thinking of adding a small perturbation before taking the log, something like:

``````perturbation = 1e-6
w1 = torch.log(s + perturbation)
w2 = torch.log(1 - s + perturbation)
``````

Do you think that makes sense, or do you have any other suggestions?

Hi Berkay!

Have you verified that the divergence occurs in `log (1 - s)`, and
not, perhaps, in a `log()` embedded in `log_softmax()`? Can you
check that `s` has become greater than (or equal to or close to) `1.0`?

This would likely cure the immediate problem, but I would probably
use `torch.clamp()`, rather than adding `perturbation`.
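For example, a quick sketch of clamping before the log (the value of `eps` is a hypothetical floor):

```python
import torch

# suppose training has pushed s all the way to 1.0
s = torch.tensor(1.0, requires_grad=True)

eps = 1e-6  # hypothetical floor
w2 = torch.log((1 - s).clamp(min=eps))  # log(eps) instead of log(0) = -inf
w2.backward()

print(w2.item())      # finite
print(s.grad.item())  # finite (zero here, since the clamp is active)
```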

Is there anything in your training or the structure of your problem
that "knows" that `s` is supposed to be a probability, and keeps it
between `0` and `1`? Your training might happily push `s` outside its
valid range, leading to `nan`s in `log()`.

I would suggest training the logit of `s`. (It will run from `-inf` to
`inf`.) Let's call it `t`. Then replace `log (s)` and `log (1 - s)` with:

``````torch.nn.functional.logsigmoid (t)
torch.nn.functional.logsigmoid (-t)
``````

(where we're using pytorch's `logsigmoid()` to avoid the potential
numerical instability of calling `log (sigmoid (t))` in two steps, and
the identity `1 - sigmoid (t) = sigmoid (-t)` for the `log (1 - s)` term).
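A small sketch comparing the one-step and two-step versions (the values of `t` are made up):

```python
import torch
import torch.nn.functional as F

t = torch.tensor([-30.0, -1.0, 0.0, 1.0, 30.0])  # unconstrained logit of s

log_s = F.logsigmoid(t)     # log(s), computed stably in one step
log_1ms = F.logsigmoid(-t)  # log(1 - s), since 1 - sigmoid(t) == sigmoid(-t)

# the naive two-step version underflows for large t
naive = torch.log(1.0 - torch.sigmoid(t))
print(naive)    # -inf in the last entry
print(log_1ms)  # finite everywhere
```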

Best.

K. Frank


Hi KFrank,

Thanks again for your reply. Yes, I verified that the problem is indeed coming from the log, and yes, s was becoming 1.

I was already using a sigmoid function for s, so yes, my model has a layer that ensures s is in the range [0, 1], and s is a learnable parameter. However, I was applying the log separately, after the sigmoid. I did not know that LogSigmoid exists. Now I am using it and it works just fine.

Thank you very much for your help!
Best,
Berkay