Tracking down NaN gradients

I have noticed that there are NaNs in the gradients of my model. This is confirmed by torch.autograd.detect_anomaly():

RuntimeError: Function 'DivBackward0' returned nan values in its 1th output.

I do not know which division causes the problem, since DivBackward0 is not a unique name. However, I have added asserts to all divisions (like assert torch.all(divisor != 0)), and I also have plenty of asserts that check for NaNs in general (like assert torch.all(~torch.isnan(t))).
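For illustration, one way to automate such checks is a forward hook that raises on NaNs in every module's output (this is a sketch, not my actual code; the toy model is a placeholder):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Sigmoid())  # placeholder model

def nan_check_hook(module, inputs, output):
    # Fail fast if a module produces NaNs in its forward pass.
    if isinstance(output, torch.Tensor) and torch.isnan(output).any():
        raise RuntimeError(f"NaN detected in the output of {module}")

for module in model.modules():
    module.register_forward_hook(nan_check_hook)

model(torch.randn(2, 4))  # the hooks run on every forward call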

I also iterate over the autograd graph and register hooks that print each backward function and check its gradients for NaNs, using the following code:

def iter_graph(root, callback):
    # Walk every node of the autograd graph exactly once and call `callback` on it.
    queue = [root]
    seen = set()
    while queue:
        fn = queue.pop()
        if fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                queue.append(next_fn)
        callback(fn)

def register_hooks(var):
    fn_dict = {}
    def hook_cb(fn):
        def register_grad(grad_input, grad_output):
            # grad_output: gradients flowing into this node; grad_input: gradients it computes.
            print(fn)
            assert all(t is None or torch.all(~torch.isnan(t)) for t in grad_input), f"{fn} grad_input={grad_input} grad_output={grad_output}"
            assert all(t is None or torch.all(~torch.isnan(t)) for t in grad_output), f"{fn} grad_input={grad_input} grad_output={grad_output}"
            fn_dict[fn] = grad_input
        fn.register_hook(register_grad)
    iter_graph(var.grad_fn, hook_cb)
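
For reference, I call register_hooks on the loss right before backward; here is a minimal usage sketch with a toy model standing in for my real one:

model = torch.nn.Linear(4, 1)        # toy stand-in for the real network
out = model(torch.randn(8, 4))
loss = out.pow(2).mean()
register_hooks(loss)                 # walk the graph of the loss and attach the hooks
loss.backward()                      # the asserts fire at the first node whose gradients contain NaNs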

On my real model, the output looks like this:

<ViewBackward object at 0x7fb79bae50d0>
<SubBackward0 object at 0x7fb79bae5130>
<DivBackward0 object at 0x7fb79bae51c0>
<DivBackward0 object at 0x7fb79bae5190>
<SliceBackward object at 0x7fb79bae50a0>
<SliceBackward object at 0x7fb79bae5400>
<ViewBackward object at 0x7fb79badcfd0>
...
<SigmoidBackward object at 0x7fb79bacc3d0>
<AddmmBackward object at 0x7fb79bacc430>
<TBackward object at 0x7fb79bacc4c0>
<CudnnBatchNormBackward object at 0x7fb79bacc490>
...
<torch.autograd.function.BilinearInterpolationBackward object at 0x7fb8cc79c4a0>
<torch.autograd.function.BilinearInterpolationBackward object at 0x7fb8cc79c3c0>
<torch.autograd.function.BilinearInterpolationBackward object at 0x7fb8cc79c2e0>

And then it fails with:

AssertionError: <torch.autograd.function.BilinearInterpolationBackward object at 0x7fb8cc79c2e0> grad_input=(tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [nan],
        [0.],
        [0.],
        [0.],
        [nan]], device='cuda:0'), tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [nan],
        [0.],
        [0.],
        [0.],
        [nan]], device='cuda:0'), None, None, None, None, None, None, None, None, None) grad_output=(tensor([[ 0.0000e+00, -7.8456e+29],
        [ 0.0000e+00,  2.4914e+31],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -2.4474e+30],
        [ 0.0000e+00,  5.9677e+30],
        [        nan,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  9.7542e+30],
        [ 0.0000e+00, -2.9419e+30],
        [        nan,  0.0000e+00]], device='cuda:0'),)

Interestingly, it does not fail immediately after the DivBackward0 nodes. BilinearInterpolationBackward has NaNs both in its incoming gradients (grad_output) and in the gradients it produces (grad_input), which means it is only propagating NaNs that already exist and does not cause the problem either.

I am pretty lost at this point. What else can I do to track down the NaN gradients?

Edit:

  • Checking for Inf does not help
  • Bigger batch size does not help

Edit 2:

If I disable cuDNN with torch.backends.cudnn.enabled = False, then I get infinity in a MulBackward0. Investigating further.

Hi,

Didn’t the second stack trace printed when anomaly mode is enabled show you where the bad div is?

No, the stack trace is useless in this case. The exception is thrown in C++. Also, by now I am pretty sure that div isn’t the problem here. I guess detect_anomaly does not work well with some cudnn optimizations.


Can you still share the two stack traces?

This is the one with cudnn disabled:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-3906501279d7> in <module>
     21 
     22         # get_dot = register_hooks(loss)
---> 23         loss.backward()
     24         # dot = get_dot()
     25 

venv/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    193                 products. Defaults to ``False``.
    194         """
--> 195         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    196 
    197     def register_hook(self, hook):

venv/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     95         retain_graph = create_graph
     96 
---> 97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
     99         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: Function 'SelectBackward' returned nan values in its 0th output.

And with cudnn:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-3906501279d7> in <module>
     21 
     22         # get_dot = register_hooks(loss)
---> 23         loss.backward()
     24         # dot = get_dot()
     25 

venv/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    193                 products. Defaults to ``False``.
    194         """
--> 195         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    196 
    197     def register_hook(self, hook):

venv/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     95         retain_graph = create_graph
     96 
---> 97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
     99         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: Function 'DivBackward0' returned nan values in its 1th output.

Isn’t there a second stack trace just above this one when anomaly mode is enabled?

No. Should there be one? I am using PyTorch 1.4.

For example, in the code below you can see the warning that points to the faulty part of the forward pass:

import torch
from torch import autograd
autograd.set_detect_anomaly(True)

return_nan = True

class MyBadFn(autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        return inp ** 2

    @staticmethod
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        res = 2 * grad_out * inp
        if return_nan:
            res = grad_out / 0 * 0
        return res


inp = torch.rand(10, requires_grad=True)

out = MyBadFn.apply(inp)

out.sum().backward()

prints

[W python_anomaly_mode.cpp:60] Warning: Error detected in MyBadFnBackward. Traceback of forward call that caused the error:
  File "foo.py", line 24, in <module>
    out = MyBadFn.apply(inp)
 (function print_stack)
Traceback (most recent call last):
  File "foo.py", line 26, in <module>
    out.sum().backward()
  File "/Users/albandes/workspace/pytorch_dev/torch/tensor.py", line 184, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/Users/albandes/workspace/pytorch_dev/torch/autograd/__init__.py", line 107, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'MyBadFnBackward' returned nan values in its 0th output.

For me this looks different:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-c48dcde9ba4c> in <module>
     24 out = MyBadFn.apply(inp)
     25 
---> 26 out.sum().backward()

venv/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    193                 products. Defaults to ``False``.
    194         """
--> 195         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    196 
    197     def register_hook(self, hook):

venv/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     95         retain_graph = create_graph
     96 
---> 97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
     99         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: Function 'MyBadFnBackward' returned nan values in its 0th output.

Regarding my “real” problem: I seem to have a gradient explosion and am now trying to figure out why.

Edit: I think my Jupyter Notebook is messing with the errors. When I run this in a normal Python file, I get the following instead:

venv/bin/python test.py
Warning: Traceback of forward call that caused the error:
  File "test.py", line 24, in <module>
    out = MyBadFn.apply(inp)
 (print_stack at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "test.py", line 26, in <module>
    out.sum().backward()
  File "venv/lib/python3.8/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "venv/lib/python3.8/site-packages/torch/autograd/__init__.py", line 97, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'MyBadFnBackward' returned nan values in its 0th output.

Ho :confused: Is the Jupyter notebook silencing warnings? (The second stack trace is actually a Python warning.) What happens if you run import warnings; warnings.simplefilter("always") in the notebook to try to make all warnings appear?

This does not help either. In case you want to reproduce:
Python 3.8.2
torch==1.4.0

$ jupyter --version
jupyter core     : 4.6.3
jupyter-notebook : 6.0.3
qtconsole        : 4.7.2
ipython          : 7.13.0
ipykernel        : 5.2.0
jupyter client   : 6.1.2
jupyter lab      : 2.0.1
nbconvert        : 5.6.1
ipywidgets       : 7.5.1
nbformat         : 5.0.5
traitlets        : 4.3.3

I have also tried some other things, but did not get it working.

Interesting, I can actually repro on colab: https://colab.research.google.com/drive/1gbUpNxLMlL4H-mKsgTh7UiU_JTSPHnvZ
I’ll open an issue to check that.

Going back to your original issue, if you can run the same code in a script outside of the notebook, that might help you pinpoint things.
Otherwise, if you can make a simple repro, I can take a closer look as well.

Thanks for your help and for creating the issue. I have since been able to solve the problem. I really did have a gradient explosion, which caused Infs (which were then turned into NaNs by cudnn).

To mitigate the problem, I am now clipping the gradients in the backward of a torch.autograd.Function (register_hook on an intermediate tensor would work as well). I cannot use normal gradient clipping after backward(), since the explosion already happens during the very first backward pass, and by then the Infs/NaNs have propagated through the graph.
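
For illustration, a minimal sketch of the idea (not my exact code; the threshold and the toy tensors are placeholders):

import torch

class ClampGrad(torch.autograd.Function):
    """Identity in the forward pass, clamps the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Clamp element-wise so no single entry can blow up towards Inf/NaN.
        return grad_out.clamp(-1e4, 1e4)

x = torch.randn(4, requires_grad=True)
y = ClampGrad.apply(x * 1e6)   # insert wherever the gradients start to explode
y.pow(2).sum().backward()      # gradients flowing backward through ClampGrad are clamped to +/-1e4
print(x.grad)

# The same effect with a hook on an intermediate tensor:
# t.register_hook(lambda g: g.clamp(-1e4, 1e4))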


Hello @joel @albanD,

I am also encountering the same problem in my code. In my case, the MSE loss is NaN. I tried clipping the gradients by adding the lines below:

torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=1.0, norm_type=2.0)
optimizer.step()
optimizer.zero_grad()

But I still get NaN. Would you be able to share what you actually modified?

I get the same error too. I simplified my net to a single Linear layer and removed the activation function, and this error disappeared, but I got another error: Function ‘AddmmBackward’ returned nan values in its 1th output.

The notebook doesn’t seem to exist anymore, @albanD - do you have another version of it?

The notebook was just running the code in the message above and showing that the warning was silenced.

Can you please share details about your approach? I am also facing the same problem and normal gradient clipping is of no use.