Finding source of NaN in forward pass

I ran into an issue where I occasionally get a NaN during the forward pass, which makes my loss NaN as well and breaks the training process. Since the graphs I’m generating are quite complex, it’s difficult for me to pin down exactly what in the forward pass produces the NaNs.
I tried using torch.autograd.detect_anomaly, but it only raises an error on the loss function, which is unhelpful in my case. Is there a way to also get the errors directly in the forward pass so I can find the actual source of the issue?
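
For reference, I’m enabling it roughly like this (a minimal sketch; model, criterion, data, and target stand in for my actual training loop):

    import torch

    # wrap the forward and backward pass so autograd records tracebacks
    # and raises as soon as a backward function produces NaN values
    with torch.autograd.detect_anomaly():
        output = model(data)
        loss = criterion(output, target)
        loss.backward()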

What output does detect_anomaly yield?
Were you able to isolate the NaN to a few iterations (or a single one)?
If so, you could use forward hooks and temporarily store each submodule’s output in order to track down the source of the first NaN occurrence.
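
Something along these lines might work (just a rough sketch; the activations dict and the name-keyed hooks are placeholders you’d adapt to your model):

    import torch

    activations = {}  # most recent output of each submodule, keyed by module name

    def save_output_hook(name):
        # returns a forward hook that stashes the module's output under `name`
        def hook(module, inp, output):
            activations[name] = output
        return hook

    for name, submodule in model.named_modules():
        submodule.register_forward_hook(save_output_hook(name))

    # after the iteration that produces the NaN, inspect the stored outputs
    for name, out in activations.items():
        if torch.is_tensor(out) and torch.isnan(out).any():
            print(f"NaN in output of {name}")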

detect_anomaly yields

RuntimeError: Function 'MseLossBackward' returned nan values in its 0th output.

which, as I mentioned in my first post, isn’t very helpful in this case, since the NaNs are already present in the input to the loss function.
There’s usually exactly one NaN in the first batch - interestingly, the exact index at which the NaN occurs in the batch (or whether it occurs at all) seems to vary despite my use of torch.manual_seed. Following your advice, I used the following hook to raise a RuntimeError whenever a NaN is encountered and registered it on my model’s submodules:

    def nan_hook(self, inp, output):
        # forward hook: raise as soon as any submodule produces a NaN in its output
        if not isinstance(output, tuple):
            outputs = [output]
        else:
            outputs = output

        for i, out in enumerate(outputs):
            nan_mask = torch.isnan(out)
            if nan_mask.any():
                print("In", self.__class__.__name__)
                nan_indices = nan_mask.nonzero()
                batch_indices = nan_indices[:, 0].unique(sorted=True)
                raise RuntimeError(
                    f"Found NaN in output {i} at indices: {nan_indices}, "
                    f"where: {out[batch_indices]}"
                )

    # register the hook on every submodule of the model
    for submodule in model.modules():
        submodule.register_forward_hook(nan_hook)

From what I can tell, the NaNs are generated by a call to torch.nn.functional.gumbel_softmax. I tried to check for the offending input as follows:

    # check whether gumbel_softmax itself produces the NaNs
    gumbel_one_hot = functional.gumbel_softmax(raw_attention, hard=True, dim=1)

    nan_mask = torch.isnan(gumbel_one_hot)
    if nan_mask.any():
        print(nan_mask.nonzero())
        indices = nan_mask.nonzero()[:, 0].unique(sorted=True)
        print("Input:", raw_attention[indices])
        print("Output:", gumbel_one_hot[indices])
        raise RuntimeError("NaN encountered in gumbel_softmax")

However, it prints the following input and output, which shouldn’t produce NaNs, and I can’t reproduce the problem by using the same seed on the GPU in a separate script:

Input: tensor([[-0.2471, -0.2118, -0.1458, -0.1783, -0.1228, -0.1062, -0.0502,  0.0112,
          0.0097, -0.2189,    -inf,    -inf,  0.2214,  0.2723,  0.0808,  0.1290,
          0.1007,  0.4129,  0.4808,  0.3924,  0.2889,  0.4916,    -inf,    -inf]],
       device='cuda:0', grad_fn=<IndexBackward>)
Output: tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       device='cuda:0', grad_fn=<IndexBackward>)
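
For completeness, my standalone reproduction attempt looks roughly like this (only a sketch - the seed and the logits here are illustrative rather than the exact values from the training run):

    import torch
    from torch.nn import functional

    # requires a GPU, since the issue only shows up on CUDA for me
    torch.manual_seed(0)  # illustrative seed, not the one from the training script

    # logits shaped like the printout above (truncated), including the -inf entries
    raw_attention = torch.tensor(
        [[-0.2471, -0.2118, -0.1458, float('-inf'), 0.2214, 0.2723, float('-inf')]],
        device='cuda:0',
    )

    out = functional.gumbel_softmax(raw_attention, hard=True, dim=1)
    print(out, torch.isnan(out).any())  # never turns up a NaN in isolation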

Is this issue possibly related to https://github.com/pytorch/pytorch/issues/22442?

Really great debugging!
The linked issue seems to only produce the NaNs on a GPU and in PyTorch 1.1.0.
Could you try to update to the nightly build and check if this issue still occurs?

From my testing so far, upgrading to the nightly build has finally solved the modules’ NaN issues. Thanks again for the advice.

Hi, I am facing the same problem with NaNs after some epochs, where the error says

RuntimeError: Function 'MseLossBackward' returned nan values in its 0th output.

Could you please suggest a solution? And how do I use the nightly build to solve this? Thanks in advance.
