NaN loss in RNN model?

Hi, I’m trying out the code from the awesome practical-python examples, and I’m replacing the text with a slightly bigger one (the original is 164 KB; mine is 966 KB).

However, the loss becomes NaN after several iterations.

The model by @spro is below:

import torch
import torch.nn as nn
from torch.autograd import Variable

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        self.hidden_size = hidden_size
        
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return Variable(torch.zeros(1, self.hidden_size))

Then I replaced the LogSoftmax with softmax + log(output + eps), like this:

import torch
import torch.nn as nn
from torch.autograd import Variable

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        self.hidden_size = hidden_size
        
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
    
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        
        output = self.softmax(output)
        output = output.add(1e-8)
        output = output.log()
        
        return output, hidden

    def initHidden(self):
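        # dtypeFloat is assumed to be defined elsewhere in the script,
        # e.g. dtypeFloat = torch.cuda.FloatTensor or torch.FloatTensor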
        return Variable(torch.zeros(1, self.hidden_size).type(dtypeFloat))

The result is slightly better, but it still ends up as NaN.

Could anybody give me some suggestions on how such a thing could happen?
Thanks.

Besides, another problem is how to detect NaN in PyTorch.

OK, this is solved: type(nan) shows that it’s a float, so it can be checked with math.isnan(x).
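For a concrete check, here is a minimal sketch for recent PyTorch versions (loss and output are placeholder names from a typical training loop, not code from this thread):

import math
import torch

# scalar loss: convert to a Python float and test it
if math.isnan(loss.item()):
    print('loss became NaN')

# tensor-level check, e.g. on the model output
if torch.isnan(output).any():
    print('output contains NaN values')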

But the NaN problem above remains unsolved…

The NaNs appear because applying softmax and log separately can be a numerically unstable operation.

If you’re using CrossEntropyLoss for training, you could use the F.log_softmax function at the end of your model and use NLLLoss. The loss will be equivalent, but much more stable.
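As a quick sanity check of that equivalence, here is a minimal sketch with made-up shapes (not code from the thread):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)              # batch of 4 samples, 10 classes
target = torch.randint(0, 10, (4,))      # integer class indices

loss_ce = F.cross_entropy(logits, target)
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), target)

print(torch.allclose(loss_ce, loss_nll))  # True: CrossEntropyLoss == log_softmax + NLLLoss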

Thanks for the reply. But I have added an epsilon inside the log, and my loss function was already NLLLoss.

And I used torch.nn.LogSoftmax and NLLLoss at the beginning; its result was the worst one, and that is why I separated the softmax and log operations…

def cross_entropy(input, target, weight=None, size_average=True):
    r"""This criterion combines `log_softmax` and `nll_loss` in one single class.
    See :class:`torch.nn.CrossEntropyLoss` for details.
    Args:
        input: Variable :math:`(N, C)` where `C = number of classes`
        target: Variable :math:`(N)` where each value is `0 <= targets[i] <= C-1`
        weight (Variable, optional): a manual rescaling weight given to each
                class. If given, has to be a Variable of size "nclasses"
        size_average (bool, optional): By default, the losses are averaged
                over observations for each minibatch. However, if the field
                sizeAverage is set to False, the losses are instead summed
                for each minibatch.
    """
    return nll_loss(log_softmax(input), target, weight, size_average)

Shouldn’t CrossEntropyLoss behave the same as log_softmax + NLLLoss?

I have the same problem as you, and replacing cross_entropy with log_softmax + nll_loss doesn’t work.

It seems to me that NaN is more likely to happen when the network is big; when I try the same architecture at a much smaller scale, the NaNs disappear.

Yep, I set eps to 1e-6, and I don’t think this can be explained by “numerical instability” alone: instability might make convergence difficult, but it cannot explain a negative output of exp, which is simply wrong.

I further checked my code and found out that my problem is due to gradient explosion in the RNN. This code might help you if the cause is the same as mine; pay attention to the function ‘clip_gradient’.
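The referenced code isn’t reproduced in this post, but a manual clipping helper in that spirit might look like the sketch below (clip_gradient and the threshold value are illustrative, not necessarily the original’s):

def clip_gradient(model, clip_value):
    # clamp every parameter's gradient to [-clip_value, clip_value] in place
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-clip_value, clip_value)

# usage: call it after loss.backward() and before the parameter update
# clip_gradient(model, 5.0)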

Thanks, that is exactly what I missed!

Hi all,

I would agree that it is very likely due to gradient explosion.

This URL is very helpful:
https://machinelearningmastery.com/exploding-gradients-in-neural-networks/
Besides clipping, it also suggests using ReLU, LSTM, etc.

Hope it helps!
Yisong

Thank you @splinter, this is very helpful and successfully rid my LSTM of NaNs.

To abbreviate @splinter’s answer a little, the trick is to call torch.nn.utils.clip_grad_norm_ just after loss.backward(), like this:

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
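For context, here is a sketch of where that call sits in a typical training step (optimizer, model, criterion, and the batch variables are assumed names, not from the post above):

optimizer.zero_grad()
output, hidden = model(input, hidden)
loss = criterion(output, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # clip before the update
optimizer.step()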

Thanks @Iridium_Blue and @splinter, this helped me a lot.

I notice that the clip_grad_norm trick appears frequently in cutting-edge solutions like fastai; aside from preventing NaNs, it helps to correct a fundamental weakness of any RNN, namely vanishing and exploding gradients. This can greatly improve the model’s performance.

Just as a suggestion for anyone who has this problem: make sure you check your inputs thoroughly and confirm they do not contain NaN. I had this same problem and spent a week trying all the different solutions, such as gradient clipping, just to figure out that one of my input values was NaN, which threw off my calculations and led to a NaN loss. :grinning:
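A quick way to guard against that is to validate each batch before the forward pass; a minimal sketch (inputs is a placeholder name for your float input tensor):

import torch

assert not torch.isnan(inputs).any(), 'NaN found in the input batch'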
