Moving to numerically stable log-sum-exp leads to extremely large loss values

I am working on a network that uses an LSTM along with MDNs (mixture density networks) to predict some distributions. The loss function I use for these MDNs involves fitting my target data to the predicted distributions. To compute the loss, I take the log-sum-exp of the log_probs of the target data under those distributions. With the standard (naive) log-sum-exp, I get reasonable initial loss values (around 50-70), although training later hits NaNs and breaks. Based on what I have read online, a numerically stable version of log-sum-exp is required to avoid this problem. However, as soon as I use the stable version, my loss values shoot up to the order of 15-20k. They do come down over the course of training, but eventually they also lead to NaNs.

NOTE: I did not use PyTorch's logsumexp function, since I needed a weighted summation based on my mixture components.

def log_sum_exp(self, value, weights, dim=None):
    # Numerically stable weighted log-sum-exp: shift by the max before exponentiating.
    eps = 1e-20
    m, idx = torch.max(value, dim=dim, keepdim=True)
    return m.squeeze(dim) + torch.log(
        torch.sum(torch.exp(value - m) * weights.unsqueeze(2), dim=dim) + eps)

def mdn_loss(self, pi, sigma, mu, target):
    eps = 1e-20
    target = target.unsqueeze(1)
    m = torch.distributions.Normal(loc=mu, scale=sigma)
    probs = m.log_prob(target)
    # Size of probs is batch_size x num_mixtures x num_out_features
    # Size of pi is batch_size x num_mixtures
    loss = -self.log_sum_exp(probs, pi, dim=1)
    return loss.mean()
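
For reference, the same weighted log-sum-exp can also be expressed by folding the mixture weights into log space and calling torch.logsumexp. The snippet below is just a rough sanity-check sketch (it assumes the shapes noted in the comments above and strictly positive pi, e.g. softmax outputs), not part of my model:

import torch

def weighted_logsumexp_check(probs, pi, eps=1e-20):
    # log(sum_k pi_k * exp(log_prob_k)) == logsumexp_k(log_prob_k + log(pi_k))
    log_pi = torch.log(pi + eps).unsqueeze(2)   # broadcast over the out_features axis
    return torch.logsumexp(probs + log_pi, dim=1)

batch, num_mixtures, num_out_features = 4, 5, 3
probs = torch.randn(batch, num_mixtures, num_out_features)
pi = torch.softmax(torch.randn(batch, num_mixtures), dim=1)
print(weighted_logsumexp_check(probs, pi))   # should match log_sum_exp(probs, pi, dim=1) up to the eps terms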

Upon adding anomaly detection, the NaNs seem to occur at:
probs = m.log_prob(target)

Seeing such huge initial loss values just from moving to the numerically stable version has led me to believe I have a bug somewhere in my implementation. Any help would be appreciated.

Hi Umang!

I haven’t checked your weighted, stable log-sum-exp implementation
carefully, but, at first glance, it looks correct to me.

If your issues really do show up first with probs = m.log_prob(target),
then your log_sum_exp() is likely not the problem.

Note that log_prob() itself can overflow to -inf, depending on
the value of target and on the distribution, in particular on the
value of sigma.

Here is a brief pytorch (version 0.3.0) session that illustrates this:

>>> import torch
>>> torch.distributions.Normal (0.0, 1.0).log_prob (torch.FloatTensor ([0, 10, 1.e5, 1.e10, 1.e19, 1.e20]))

-9.1894e-01
-5.0919e+01
-5.0000e+09
-5.0000e+19
-5.0000e+37
       -inf
[torch.FloatTensor of size 6]

>>> torch.distributions.Normal (0.0, 1.e-15).log_prob (torch.FloatTensor ([0, 10, 1.e5, 1.e10, 1.e19, 1.e20]))

 3.3620e+01
-5.0000e+31
       -inf
       -inf
       -inf
       -inf
[torch.FloatTensor of size 6]

The normal distribution falls off exponentially, so it is possible to
have “reasonable” values of target (and sigma) for which the normal
distribution “underflows” to 0.0, at which point the log in log_prob()
will return -inf.

I assume that target is some well-defined, static data in your training
dataset. I would first check that your target data doesn’t contain any outlandish values.

Then I assume that your training is adjusting your value of sigma
(and mu). If your training, for some reason, is pushing sigma
towards zero, then your log_prob() will be getting pushed into
the far tails of the normal distribution (even for reasonable target
values), and you will get -inf, as illustrated above.

So I would check for small values of sigma, perhaps with something
as simple as

    if (sigma < 1.e-10).any():
        print("Yikes!  Tiny sigma!")

If I understand your description correctly, it sounds like moving to your
numerically-stable log_sum_exp() postponed your training issues,
so I would say that that could be viewed as an improvement. (I don’t
see it as evidence for a bug in log_sum_exp().)

My guess is that your training is driving your model in an unhealthy
direction (e.g., sigma --> 0). I would check for that first.

Good luck.

K. Frank

Hi Frank,
Thanks a lot for your suggestions.
“I would first check that your target data doesn’t contain any outlandish values.”
Great advice! It turns out I did have some outlandish target values. Removing them and normalizing the data has fixed the issue. However, I am still a bit confused about why the loss values differed so much in magnitude just from introducing the numerically stable log-sum-exp.

m, idx = torch.max(value, dim=dim, keepdim=True)
return m.squeeze(dim) + torch.log(
    torch.sum(torch.exp(value - m) * weights.unsqueeze(2), dim=dim) + eps)

Does it have something to do with the max we are taking in this function?

Thanks,
Umang

Hello Umang!

This is just speculation on my part. I don’t have a clear picture of
the exact sequence of events, but it sounded to me like when you
used the naive log-sum-exp, your training broke early with NaNs.
And when you used the more stable log-sum-exp, your training
gave very large loss values, and then later broke with NaNs.

I assume that the naive version broke before you got to the large
loss values, so they were “hidden.” The stable version was able
to run further, so it was able to reach the large loss values. (Again,
just speculation).
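
Just to illustrate the difference (with made-up numbers, and keeping
only the mixture dimension for brevity): when the log-probs are very
negative, a naive log-sum-exp returns -inf, so the loss becomes inf
and the backward pass produces NaNs, while the stable version returns
a huge, but finite, value:

import torch

# Made-up mixture log-probs, e.g. from an outlandish target value.
value = torch.tensor([[-50000.0, -90000.0]])
weights = torch.tensor([[0.3, 0.7]])

# Naive version: exp() underflows to 0.0, so log(0.0) = -inf.
naive = torch.log(torch.sum(torch.exp(value) * weights, dim=1))

# Stable version (the same max-shift trick as in your log_sum_exp()).
m, _ = torch.max(value, dim=1, keepdim=True)
stable = m.squeeze(1) + torch.log(
    torch.sum(torch.exp(value - m) * weights, dim=1) + 1e-20)

print(naive)    # tensor([-inf])            -> loss = inf, then NaN gradients
print(stable)   # roughly tensor([-50001.2]) -> loss around 50001, large but finite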

Best.

K. Frank