Why does my model return nan?

The model is here:

import torch
import torch.nn as nn

class Actor(nn.Module):
    """Maps a state to the mean and std of an action distribution."""
    def __init__(self, state_size, action_size, hidden_size=512):
        super(Actor, self).__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size
        self.action_size = action_size
        # Input projection: state -> hidden features
        self.block_state = nn.Sequential(
            nn.Linear(state_size, hidden_size),
        )
        # Shared hidden layers
        self.block_hidden = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Linear(hidden_size, hidden_size),
        )
        # Output head for the mean
        self.block_mean = nn.Sequential(
            nn.Linear(hidden_size, action_size),
        )
        # Output head for the std
        self.block_std = nn.Sequential(
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, state):
        out = self.block_state(state)
        out = self.block_hidden(out)
        mean = self.block_mean(out)
        std = self.block_std(out)
        return mean, std

The output is:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan]], grad_fn=)
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan]], grad_fn=)
I’m sure the input doesn’t contain any nan value.

There are many potential reasons. Most likely exploding gradients. The two things to try first:

  1. Normalize the inputs

  2. Lower the learning rate
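
Here is a minimal sketch of both suggestions. The data, the stand-in `nn.Linear` model, the placeholder loss and the learning rate of 1e-4 are all made up for illustration. Gradient clipping is an extra that nobody suggested above; I only added it as one common guard against exploding gradients.

```python
import torch
import torch.nn as nn

def standardize(x, eps=1e-8):
    """Scale each feature column to zero mean and unit variance."""
    mean = x.mean(dim=0, keepdim=True)
    std = x.std(dim=0, keepdim=True)
    return (x - mean) / (std + eps)

torch.manual_seed(0)
states = torch.randn(64, 8) * 1000.0 + 500.0   # made-up, badly scaled inputs
states = standardize(states)

model = nn.Linear(8, 4)                        # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # smaller step size

loss = model(states).pow(2).mean()             # placeholder loss
optimizer.zero_grad()
loss.backward()
# Rescale gradients so their global norm is at most 1.0.
# The return value is the norm measured before clipping.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(float(grad_norm))
```

If the printed norm is already huge on the first step, that points at the inputs or the initialization rather than the learning rate.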


I will try normalizing the inputs and see whether it works.
It can’t be a learning rate problem, since the nan appears from the very beginning.

Could you additionally check your input for inf or nan values?
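
Something like the following could do that check. The helper name `report_non_finite` and the example tensor are made up for illustration.

```python
import torch

def report_non_finite(name, t):
    """Count nan and inf entries in a tensor and print a summary if any exist."""
    n_nan = int(torch.isnan(t).sum())
    n_inf = int(torch.isinf(t).sum())
    if n_nan or n_inf:
        print(f"{name}: {n_nan} nan, {n_inf} inf out of {t.numel()} values")
    return n_nan, n_inf

# Made-up input containing one inf and one nan
state = torch.tensor([[0.5, float("inf"), -1.0],
                      [float("nan"), 2.0, 3.0]])
counts = report_non_finite("state", state)
```

Note that an inf in the input is enough to produce nan a layer or two later, so checking only for nan can miss the cause.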


The input doesn’t contain any nan value. I guess something goes wrong in block_hidden, since the outputs of both block_mean and block_std contain nan values.
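
One way to find out which block produces the first nan is to attach forward hooks. This is only a sketch: `add_nan_hooks` is a made-up helper, and the toy two-layer model with a deliberately corrupted weight stands in for the real Actor.

```python
import torch
import torch.nn as nn

def add_nan_hooks(model):
    """Record names of submodules whose output contains nan or inf."""
    flagged = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                flagged.append(name)
        return hook

    # Skip the root module (empty name) so only submodules are reported.
    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules() if n]
    return flagged, handles

toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
with torch.no_grad():
    toy[1].weight.fill_(float("nan"))   # corrupt the second layer on purpose

flagged, handles = add_nan_hooks(toy)
toy(torch.randn(1, 4))
print(flagged)                          # names in the order the modules ran

for h in handles:                       # remove the hooks when done
    h.remove()
```

The first name in the list is the earliest module with a bad output, which is usually the place to start looking.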

@DXZ_999 @rasbt
Hello, there is another possibility: if the output contains some very large values (abs(value) > 1e20), then applying nn.LayerNorm to it might return an all-nan vector.
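
If very large activations really are the trigger, one possible workaround is to clamp them before the normalization. This is only a sketch: `ClampedLayerNorm` and the bound of 1e4 are made up, and clamping changes what the model computes, so treat it as a diagnostic rather than a fix.

```python
import torch
import torch.nn as nn

class ClampedLayerNorm(nn.Module):
    """LayerNorm preceded by a clamp (the bound is a hypothetical choice)."""
    def __init__(self, features, bound=1e4):
        super().__init__()
        self.bound = bound
        self.norm = nn.LayerNorm(features)

    def forward(self, x):
        # Limit the magnitude so the variance computation cannot overflow.
        return self.norm(torch.clamp(x, -self.bound, self.bound))

x = torch.zeros(1, 8)
x[0, 0] = 1e30                          # one extreme activation
out = ClampedLayerNorm(8)(x)
print(bool(torch.isfinite(out).all()))
```

Clamping does not remove a nan that is already present, so it only helps if the values going in are large but still finite.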

A similar problem happens in my attention model. I’m pretty sure it can’t be exploding gradients in my case because:

  1. The model converges (after some iterations the loss is low and stable).
  2. Debugging shows that only a limited number of samples have this problem.
  3. It is so rare that I have to use torch.any(torch.isnan(x)) to catch it, and even then it takes multiple runs to catch a single example.
  4. Only intermediate results become nan; input normalization is implemented but the problem still exists.
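
Since only a few samples are affected, it may help to record which rows of a batch are bad instead of only detecting that something is nan. A small sketch, assuming the batch dimension comes first; `bad_sample_indices` is a made-up name.

```python
import torch

def bad_sample_indices(x):
    """Return indices along dim 0 of samples containing any nan or inf."""
    finite_per_sample = torch.isfinite(x).flatten(start_dim=1).all(dim=1)
    return torch.nonzero(~finite_per_sample).flatten().tolist()

batch = torch.randn(4, 5, 3)            # (batch, time, features), made-up data
batch[2, 1, 0] = float("nan")           # infect a single time step of sample 2
idx = bad_sample_indices(batch)
print(idx)
```

Saving the offending inputs along with these indices makes the rare case reproducible without waiting for it to happen again.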

My model handles time-series sequences. If a single vector is ‘infected’ with nan, it propagates and ruins the whole output, so I would like to know whether this is a bug and whether there is a way to address it.
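
To illustrate the propagation, here is a tiny example assuming softmax attention (the details of the model above are not given, so the shapes and values are made up). A single nan score turns the whole row of attention weights into nan, and every output that uses that row follows.

```python
import torch

# Attention scores for two queries over three keys; one score is nan.
scores = torch.tensor([[0.1, float("nan"), 0.3],
                       [0.2, 0.5, 0.1]])
weights = torch.softmax(scores, dim=-1)  # normalizes each row

values = torch.ones(3, 2)
context = weights @ values               # weighted sum of the value vectors

print(torch.isnan(weights[0]).all())     # the whole first row is nan
print(torch.isfinite(context[1]).all())  # the second query is unaffected
```

This is expected floating-point behavior rather than a bug: any sum or product that includes a nan is itself nan.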


This might sound weird, but try restarting your machine. I was facing some issue with the GPU and had to restart the system, and to my surprise it started training.

I was facing a similar issue, and passing each sample through torch.nan_to_num() did the trick. I noticed that very few samples had nan.
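
For reference, a minimal example of what torch.nan_to_num() does. The replacement values 0.0 and ±1e4 are arbitrary choices for illustration.

```python
import torch

x = torch.tensor([1.0, float("nan"), float("inf"), float("-inf")])
# Replace nan with 0.0 and clip the infinities to finite bounds.
clean = torch.nan_to_num(x, nan=0.0, posinf=1e4, neginf=-1e4)
print(clean)
```

This hides the symptom rather than the cause, so it is probably still worth finding out where the nan values come from.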

Please check the weights. The weights could be nan!
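
A quick way to do this is to loop over the named parameters. In this sketch, `non_finite_parameters` is a made-up helper and the single linear layer with a corrupted bias stands in for the real model.

```python
import torch
import torch.nn as nn

def non_finite_parameters(model):
    """Names of parameters that contain at least one nan or inf."""
    return [name for name, p in model.named_parameters()
            if not torch.isfinite(p).all()]

layer = nn.Linear(3, 2)
with torch.no_grad():
    layer.bias[0] = float("nan")        # corrupt one value on purpose
bad = non_finite_parameters(layer)
print(bad)
```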


I recommend calling torch.max(your_tensor) and torch.min(your_tensor) to check whether any of your tensors contains inf.
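
For example (the helper `value_range` and the tensor are made up):

```python
import torch

def value_range(t):
    """Smallest and largest value of a tensor as Python floats."""
    return float(torch.min(t)), float(torch.max(t))

activations = torch.tensor([0.5, -2.0, float("inf")])
lo, hi = value_range(activations)
print(lo, hi)
```

Even when no inf shows up, a very large maximum is a hint that a later layer may overflow.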

One of your features probably has a very wide range, so that even after standardization it can cause underflow or overflow. During training the model may see mostly very small values for that feature and assign it a very large weight; then a data point with a large value comes along and the output explodes. If there is a single nan in your predictions, the loss becomes nan and the model no longer trains or updates. You can work around that in the loss function, but the weight will remain large. Remove unnecessary features that have a very wide range; scaling or normalizing them might not help.
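
A rough way to spot such features is to standardize each column and look at the largest remaining magnitude. This is only a sketch: `wide_range_features`, the threshold of 10 and the synthetic data are assumptions for illustration.

```python
import torch

def wide_range_features(x, threshold=10.0):
    """Indices of columns whose standardized values exceed threshold in magnitude."""
    z = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)
    too_wide = z.abs().max(dim=0).values > threshold
    return torch.nonzero(too_wide).flatten().tolist()

torch.manual_seed(0)
data = torch.randn(1000, 3)             # made-up dataset with three features
data[0, 1] = 1e6                        # one extreme value in feature 1
flagged_features = wide_range_features(data)
print(flagged_features)
```

Here the outlier in feature 1 is still about 30 standard deviations from the mean after standardization, which is the situation described above.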

Definitely check the weights. I once checked the weights and some of them were 0.0. I added 0.01 to all of them and then it started training.

  1. Check whether your dataset contains nan values.
  2. Don’t forget to normalize your data.

If you use a function that is ill-behaved at zero, such as the square root below, make sure no input value is exactly 0. Adding a small constant is a common way to improve numerical stability:
torch.sqrt(x + 1e-8)
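
A small check of why this matters for the square root. The forward value at 0 is fine, but the derivative 1 / (2 * sqrt(x)) is infinite there, and an inf gradient easily turns into nan further back. The epsilon of 1e-8 is the one from the line above.

```python
import torch

x = torch.zeros(1, requires_grad=True)

# Plain sqrt: the gradient at x = 0 is 1 / (2 * 0) = inf.
torch.sqrt(x).sum().backward()
grad_plain = x.grad.clone()

# With the small constant: the gradient is 1 / (2 * sqrt(1e-8)), about 5000.
x.grad = None
torch.sqrt(x + 1e-8).sum().backward()
grad_eps = x.grad.clone()

print(grad_plain, grad_eps)
```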