I am working on a network that uses an LSTM together with MDNs to predict distributions. The loss function for these MDNs involves fitting my target data to the predicted distributions, and I compute it by taking the log-sum-exp over the log_probs of the targets. With the standard log-sum-exp I get reasonable initial loss values (around 50-70), although training later hits NaNs and breaks. From what I have read online, a numerically stable version of log-sum-exp is required to avoid this problem. However, as soon as I switch to the stable version, my loss values shoot up to the order of 15-20k. They do come down during training, but eventually they also lead to NaNs.
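For context, this is the max-subtraction trick I am referring to, as a minimal unweighted sketch in plain PyTorch (function name is mine):

```python
import torch

def stable_lse(x, dim):
    # Subtract the per-slice max so the largest exponent is exp(0) = 1,
    # which prevents overflow; the max is added back outside the log.
    m, _ = torch.max(x, dim=dim, keepdim=True)
    return m.squeeze(dim) + torch.log(torch.sum(torch.exp(x - m), dim=dim))

x = torch.tensor([1000.0, 1001.0, 999.0])
print(stable_lse(x, dim=0))           # finite, ~1001.41
print(torch.log(torch.exp(x).sum()))  # naive version overflows to inf
```

On values this large the naive form returns `inf` because `exp(1000)` overflows float32, while the stabilized form matches `torch.logsumexp`.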
NOTE: I did not use PyTorch's built-in logsumexp function, since I needed a weighted summation based on my mixture components.
```python
def log_sum_exp(self, value, weights, dim=None):
    eps = 1e-20
    m, idx = torch.max(value, dim=dim, keepdim=True)
    return m.squeeze(dim) + torch.log(
        torch.sum(torch.exp(value - m) * (weights.unsqueeze(2)), dim=dim) + eps)

def mdn_loss(self, pi, sigma, mu, target):
    eps = 1e-20
    target = target.unsqueeze(1)
    m = torch.distributions.Normal(loc=mu, scale=sigma)
    probs = m.log_prob(target)
    # Size of probs is batch_size x num_mixtures x num_out_features
    # Size of pi is batch_size x num_mixtures
    loss = -self.log_sum_exp(probs, pi, dim=1)
    return loss.mean()
```
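For what it's worth, I believe the weighted sum can also be folded into `torch.logsumexp` by moving the mixture weights into log space, since `log Σ πᵢ·exp(lᵢ) = logsumexp(log πᵢ + lᵢ)`. A sketch under my shape assumptions (tensor names and sizes here are made up for illustration):

```python
import torch

batch, k, feats = 4, 3, 2
probs = torch.randn(batch, k, feats)           # per-component log-probs
pi = torch.softmax(torch.randn(batch, k), 1)   # mixture weights, sum to 1

# Weighted log-sum-exp via the built-in: add log-weights, broadcast
# over the feature dimension, then reduce over the mixture dimension.
weighted = torch.logsumexp(torch.log(pi).unsqueeze(2) + probs, dim=1)
```

This keeps the whole reduction inside one numerically stable primitive instead of hand-rolling the max subtraction.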
Upon adding anomaly detection, the NaNs seem to occur at:

```python
probs = m.log_prob(target)
```
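To see how `log_prob` itself can go non-finite, here is a toy check (the value of sigma is illustrative only): when sigma gets small enough that `sigma**2` underflows to zero in float32, the squared term divides by zero:

```python
import torch

# Hypothetical tiny sigma, e.g. from an unconstrained network output.
m = torch.distributions.Normal(loc=torch.tensor(0.0),
                               scale=torch.tensor(1e-30))
print(m.log_prob(torch.tensor(1.0)))  # tensor(-inf): sigma**2 underflows to 0
```

Once a `-inf` slips into the log-probs, the max-subtraction step can produce `-inf - (-inf) = nan`, which would match NaNs first appearing around this line.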
Seeing these huge initial loss values appear just from switching to the numerically stable version has led me to believe there is a bug in my current implementation. Any help would be appreciated.