# Applying log_softmax manually in a stable way

Hi all,

I have a multiclass classification problem, and my network structure is a bit more complex than usual. In a nutshell, I have 2 sets of labels. The ground truth is always one label from one of the sets (think of it as: labels 0 to C come from one set and labels C+1 to N come from the other).

My network calculates 2 different logit vectors, one for each set, using different architectures. What I do is:

I convert these logits to probability distributions via softmax, so now I have 2 probability distributions, one for each target set: p1 and p2.

I have a learnable scalar s (in range [0, 1]) which weights the learnt probability distributions.
I calculate p1 = s * p1 and p2 = (1 - s) * p2.

Basically, depending on the case, I generate the label from one of the sets. At the end I concatenate both distributions, and the result is still a valid probability distribution since s is in range [0, 1]. Lastly, I want to compute the loss on that.
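A minimal sketch of this setup (the sizes, seed, and value of s are made up for illustration), checking that the concatenation really does sum to 1:

```python
import torch

torch.manual_seed(0)

# made-up sizes: 4 labels in the first set, 3 in the second
logits1 = torch.randn(2, 4)
logits2 = torch.randn(2, 3)
s = torch.tensor(0.3)   # learnable in the real model; fixed here for illustration

p1 = torch.softmax(logits1, dim=1)
p2 = torch.softmax(logits2, dim=1)
combined = torch.cat((s * p1, (1 - s) * p2), dim=1)

# each row still sums to 1, so the concatenation is a valid distribution
print(combined.sum(dim=1))
```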

I cannot use cross-entropy loss because it requires raw logits, so I have to use NLLLoss. However, NLLLoss requires log probabilities. Since I already applied softmax, I have to apply log manually, but this leads to instabilities and numerical issues, as mentioned in the documentation.

But I need to apply the softmax in the first place, because if I apply log_softmax then I do not have valid probability distributions and I cannot weight them in a reasonable way with the parameter s.

Any suggestions?

Won't something like `x.clamp_min(1e-6).log()` be enough?
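For example, a quick sketch (with made-up probability values, including an exact zero):

```python
import torch

p = torch.tensor([0.0, 1e-12, 0.5, 1.0])

naive = torch.log(p)               # -inf at the exact zero
clamped = p.clamp_min(1e-6).log()  # finite everywhere

print(naive)
print(clamped)
```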

Hi Berkay!

Perform your `s`, `(1 - s)` weighting in "log space" so that you work
directly with the log-probabilities and only have to call `log_softmax()`,
with its better numerical stability.

That is, because:

``````log (s * prob) = log (s) + log_prob,
``````

just add `log (s)` (and `log (1 - s)`) to your results of `log_softmax()`,
rather than multiplying the results of `softmax()` by `s` (and `(1 - s)`).

As an aside, after computing the log-probabilities as outlined above, you
could use `CrossEntropyLoss` if you wanted to. The log-probabilities
are legitimate raw-score logits; they've just already been "normalized"
(whatever I mean by that). In my mind it's a little more conceptually
straightforward to feed your log-probabilities into `NLLLoss`, but you'll
get the same answer if you use `CrossEntropyLoss`.
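A quick check of that equivalence (logits and targets are made up): because the log-probabilities already normalize, the `log_softmax()` inside `cross_entropy()` is a no-op on them.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# already-normalized log-probabilities (made-up shapes: batch of 2, 7 classes)
log_probs = F.log_softmax(torch.randn(2, 7), dim=1)
target = torch.tensor([3, 5])

nll = F.nll_loss(log_probs, target)
ce = F.cross_entropy(log_probs, target)  # applies log_softmax again -- a no-op here
print(nll.item(), ce.item())
```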

Here is a script that illustrates the above:

``````import torch
torch.__version__

import math

torch.manual_seed (2020)

# 7 logits
# first group contains 4, second group 3
N = 7
C = 4
nBatch = 2   # batch size of 2, for example

s = 1/3

logits = torch.randn ((nBatch, N))
logits

lp1 = torch.nn.functional.log_softmax (logits[:, :C], dim = 1)  # log-probs for first group
lp2 = torch.nn.functional.log_softmax (logits[:, C:], dim = 1)  # log-probs for second group

log_probs = torch.cat ((math.log (s) + lp1, math.log (1.0 - s) + lp2), dim = 1)
log_probs   # all 7 log-probs

torch.exp (log_probs).sum (dim = 1)   # check that it's a valid probability distribution

# the log_probs are legitimate logits -- verify that log_softmax() doesn't change them
torch.nn.functional.log_softmax (log_probs, dim = 1)
``````

And here is its output:

``````>>> torch.__version__
'1.6.0'
>>>
>>> import math
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x7eff574ed930>
>>>
>>> # 7 logits
>>> # first group contains 4, second group 3
>>> N = 7
>>> C = 4
>>> nBatch = 2   # batch size of 2, for example
>>>
>>> s = 1/3
>>>
>>> logits = torch.randn ((nBatch, N))
>>> logits
tensor([[ 1.2372, -0.9604,  1.5415, -0.4079,  0.8806,  0.0529,  0.0751],
        [ 0.4777, -0.6759, -2.1489, -1.1463, -0.2720,  1.0066, -0.0416]])
>>>
>>> lp1 = torch.nn.functional.log_softmax (logits[:, :C], dim = 1)  # log-probs for first group
>>> lp2 = torch.nn.functional.log_softmax (logits[:, C:], dim = 1)  # log-probs for second group
>>>
>>> log_probs = torch.cat ((math.log (s) + lp1, math.log (1.0 - s) + lp2), dim = 1)
>>> log_probs   # all 7 log-probs
tensor([[-2.0768, -4.2745, -1.7725, -3.7219, -1.0388, -1.8665, -1.8443],
        [-1.5592, -2.7127, -4.1858, -3.1831, -2.1721, -0.8934, -1.9416]])
>>>
>>> torch.exp (log_probs).sum (dim = 1)   # check that it's a valid probability distribution
tensor([1.0000, 1.0000])
>>>
>>> # the log_probs are legitimate logits -- verify that log_softmax() doesn't change them
>>> torch.nn.functional.log_softmax (log_probs, dim = 1)
tensor([[-2.0768, -4.2745, -1.7725, -3.7219, -1.0388, -1.8665, -1.8443],
        [-1.5592, -2.7127, -4.1858, -3.1831, -2.1721, -0.8934, -1.9416]])
>>>
``````

Best.

K. Frank


Dear KFrank,

Thank you very much for your reply. What you said makes perfect sense. I implemented my model as you described and it is fine. However, the loss now becomes nan after a while. I ran torch anomaly detection and the problem is caught there. It tells me:

RuntimeError: Function 'LogBackward' returned nan values in its 0th output.

and the last line called is the following (basically where I do `log(1 - s)`):

``````local_distribution = torch.log(1 - st) + F.log_softmax(attention_weights.squeeze(dim=2), dim=1)
``````

I am thinking of adding a small perturbation before taking the log, something like:

``````perturbation = 1e-6
w1 = torch.log(s + perturbation)
w2 = torch.log(1 - s + perturbation)
``````

Do you think that makes sense, or do you have any other suggestions?

Hi Berkay!

Have you verified that the divergence occurs in `log (1 - s)`, and
not, perhaps, in a `log()` embedded in `log_softmax()`? Can you
check that `s` has become greater than (or equal to or close to) `1.0`?

This would likely cure the immediate problem, but I would probably
use `torch.clamp()`, rather than adding `perturbation`.
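For example, a quick sketch of clamping before the log (the value of `eps` is a hypothetical floor):

```python
import torch

# suppose training has pushed s all the way to 1.0
s = torch.tensor(1.0, requires_grad=True)

eps = 1e-6  # hypothetical floor
w2 = torch.log((1 - s).clamp(min=eps))  # log(eps) instead of log(0) = -inf
w2.backward()

print(w2.item())      # finite
print(s.grad.item())  # finite (zero here, since the clamp is active)
```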

Is there anything in your training or the structure of your problem
that "knows" that `s` is supposed to be a probability, and keeps it
between `0` and `1`? Your training might happily push `s` outside its
valid range, leading to `nan`s in `log()`.

I would suggest training the logit of `s`. (It will run from `-inf` to
`inf`.) Let's call it `t`. Then replace `log (s)` and `log (1 - s)` with:

``````torch.nn.functional.logsigmoid (t)
torch.nn.functional.logsigmoid (-t)
``````

(where we're using pytorch's `logsigmoid()` to avoid the potential
numerical instability of calling `log (sigmoid (t))` in two steps, and
the identity `1 - sigmoid (t) = sigmoid (-t)` for the `log (1 - s)` term).
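A small sketch comparing the one-step and two-step versions (the values of `t` are made up):

```python
import torch
import torch.nn.functional as F

t = torch.tensor([-30.0, -1.0, 0.0, 1.0, 30.0])  # unconstrained logit of s

log_s = F.logsigmoid(t)     # log(s), computed stably in one step
log_1ms = F.logsigmoid(-t)  # log(1 - s), since 1 - sigmoid(t) == sigmoid(-t)

# the naive two-step version underflows for large t
naive = torch.log(1.0 - torch.sigmoid(t))
print(naive)    # -inf in the last entry
print(log_1ms)  # finite everywhere
```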

Best.

K. Frank


Hi KFrank,

Thanks again for your reply. Yes, I verified that the problem is indeed coming from the log, and yes, s was becoming 1.

I was already using a sigmoid function for s, so yes, my model has a layer that ensures s is in the range [0, 1], and s is a learnable parameter. However, I was applying the log separately, after the sigmoid. I did not know that LogSigmoid exists. Now I am using it and it works just fine.

Thank you very much for your help!
Best,
Berkay