Hi,
I’m trying to generate time-series data with an LSTM and a Mixture Density Network as described in https://arxiv.org/pdf/1308.0850.pdf
Here is a link to my implementation: GitHub - NeoVand/MDNLSTM
The repository contains a toy dataset to train the network.
During training, the LSTM layer returns nan for its hidden state after one iteration. There is a similar issue here:
The issue was caused by computing the log-sum-exp in a numerically unstable way. Here is an implementation of the weighted log-sum-exp trick that I used to fix the problem:
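```python
import torch

def weighted_logsumexp(x, w, dim=None, keepdim=False):
    # Stable log(sum(w * exp(x))) along `dim`: factor out the max so exp() cannot overflow.
    if dim is None:
        # Reduce over all elements; broadcast and flatten w too so it stays aligned with x.
        x, w = torch.broadcast_tensors(x, w)
        x, w, dim = x.reshape(-1), w.reshape(-1), 0
    xm, _ = torch.max(x, dim, keepdim=True)
    x = torch.where(
        # to prevent nasty nan's: if the max is +/-inf, x - xm would be nan
        (xm == float('inf')) | (xm == float('-inf')),
        xm,
        xm + torch.log(torch.sum(torch.exp(x - xm) * w, dim, keepdim=True)))
    return x if keepdim else x.squeeze(dim)
```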
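For reference, here is the identity `weighted_logsumexp` computes, with $m$ the maximum of $x$ over the reduced dimension (notation introduced here):

$$
\log \sum_{k=1}^{K} w_k e^{x_k} \;=\; m + \log \sum_{k=1}^{K} w_k e^{x_k - m}, \qquad m = \max_k x_k
$$

In the MDN loss, $x_k = \log \mathcal{N}(y \mid \mu_k, \sigma_k)$ and $w_k = \pi_k$, so the left side is the log of the mixture density. After the shift every exponent satisfies $x_k - m \le 0$, so nothing overflows, and the term with $x_k = m$ contributes exactly $w_k$, so the sum can only collapse to 0 if that weight is 0.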
Using that, I implemented a stable loss function:
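```python
import torch

def mdn_loss_stable(y, pi, mu, sigma):
    # Per-component Gaussian log-densities; the mixture components lie along dim 2.
    m = torch.distributions.Normal(loc=mu, scale=sigma)
    m_lp_y = m.log_prob(y)
    # Negative log-likelihood of the mixture: -log(sum_k pi_k * N(y | mu_k, sigma_k))
    loss = -weighted_logsumexp(m_lp_y, pi, dim=2)
    return loss.mean()
```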
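A possible alternative, sketched here under assumptions rather than taken from the repo: PyTorch also ships `torch.logsumexp`, so the weights can be moved inside the exponent as $\log \pi_k$ and no custom helper is needed. The function name `mdn_loss_logsumexp` is hypothetical, and the shapes (pi, mu, sigma as `(batch, seq, K)`, y as `(batch, seq, 1)`) are assumed so that dim 2 is the component axis, matching `dim=2` above. The toy tensors are illustrative only: every target sits at least 5 units from every component mean with tiny scales, so each individual density underflows in float32.

```python
import torch

def mdn_loss_logsumexp(y, pi, mu, sigma):
    # Hypothetical alternative to mdn_loss_stable.
    # Assumes pi, mu, sigma: (batch, seq, K) and y: (batch, seq, 1).
    log_p = torch.distributions.Normal(loc=mu, scale=sigma).log_prob(y)  # (batch, seq, K)
    # log sum_k pi_k N(y | mu_k, sigma_k) = logsumexp_k(log pi_k + log N(y | mu_k, sigma_k))
    log_mix = torch.logsumexp(torch.log(pi) + log_p, dim=2)  # (batch, seq)
    return -log_mix.mean()

# Toy inputs (illustrative only) where every exp(log_p) would underflow to 0.
torch.manual_seed(0)
batch, seq, K = 4, 10, 3
pi = torch.softmax(torch.randn(batch, seq, K), dim=2)
mu = torch.randn(batch, seq, K)
sigma = torch.rand(batch, seq, K) * 0.01 + 1e-3
y = mu.max(dim=2, keepdim=True).values + 5.0

stable = mdn_loss_logsumexp(y, pi, mu, sigma)
print(stable)  # large but finite
```

If the network produces `pi` with a softmax, passing `torch.log_softmax` of the logits instead of `torch.log(pi)` would also avoid `log(0)` when a weight underflows; that depends on how `pi` is produced in the actual model.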
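This worked like a charm. In general, the problem is that PyTorch does not report underflow: when the component densities are tiny, exp() silently rounds them to 0, log(0) returns -inf, and the backward pass turns that into nan.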
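A minimal sketch of that silent failure, using made-up log-densities and weights (the values are illustrative, not from the repo). The naive form gives an infinite loss and nan gradients without any warning, while the max-shifted form used in `weighted_logsumexp` stays finite:

```python
import torch

# Two very negative log-densities and their mixture weights (illustrative values).
x = torch.tensor([-1000.0, -1001.0], requires_grad=True)
w = torch.tensor([0.6, 0.4])

# Naive: exp() underflows to 0, log(0) = -inf, and no warning is raised.
loss = -torch.log(torch.sum(w * torch.exp(x)))
loss.backward()
print(loss)    # tensor(inf, ...)
print(x.grad)  # nan for both entries: -inf * 0 in the exp() backward

# Shifted: subtract the max (treated as a constant) so the largest exponent is exp(0) = 1.
xs = x.detach().clone().requires_grad_()
m = xs.max().detach()
loss_s = -(m + torch.log(torch.sum(w * torch.exp(xs - m))))
loss_s.backward()
print(loss_s)   # finite, about 1000.29
print(xs.grad)  # finite
```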