NaN in view() operation

I'm getting NaN gradients in the view() operation of torchvision's ResNet.
How reliable is torch.autograd.detect_anomaly? What are possible causes for view() yielding NaN gradients?

with torch.autograd.detect_anomaly():
    logits = net(img, gt=labels)
    loss = F.cross_entropy(input=logits, target=labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
RuntimeWarning: Traceback of forward call that caused the error:
  File "main.py", line 447, in <module>
    train()
  File "main.py", line 336, in train
    logits = net(img, gt=labels)
  File "/usr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "main.py", line 145, in forward
    return self.fc(self.net(x), gt)
  File "/usr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/lib/python3.7/site-packages/torchvision/models/resnet.py", line 161, in forward
    x = x.view(x.size(0), -1)

Traceback (Most recent call last):
447 main.py                                                     <module> --> train()
339 main.py                                                     train    --> loss.backward()
96  /usr/lib/python3.7/site-packages/torch/tensor.py            backward --> torch.autograd.backward(self, gradient, retain_graph, create_graph)
90  /usr/lib/python3.7/site-packages/torch/autograd/__init__.py backward --> allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'ViewBackward' returned nan values in its 0th output.
> /usr/lib/python3.7/site-packages/torch/autograd/__init__.py(90)backward()
-> allow_unreachable=True)  # allow_unreachable flag
(Pdb) l
 85         if retain_graph is None:
 86             retain_graph = create_graph
 87
 88         Variable._execution_engine.run_backward(
 89             tensors, grad_tensors, retain_graph, create_graph,
 90  ->         allow_unreachable=True)  # allow_unreachable flag
 91
 92
 93     def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False,
 94              only_inputs=True, allow_unused=False):
 95         r"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.

Hi,

detect_anomaly should be fairly reliable.
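For reference, anomaly mode raises on the first backward function (in backward order) whose output gradient contains NaN, and prints the traceback of the forward call that created that function. A minimal toy sketch, unrelated to your model, showing the kind of error it produces:

import torch

with torch.autograd.detect_anomaly():
    x = torch.randn(4, requires_grad=True)
    y = torch.sqrt(x - 10.0)   # all inputs negative -> NaN in forward and in the sqrt gradient
    y.sum().backward()         # raises a RuntimeError naming the offending backward function (here SqrtBackward)

So when it reports ViewBackward, the NaN first shows up in the gradient flowing out of that view, which is why it helps to inspect the gradients on either side of it.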
Could you add this to your code:

logits = net(img, gt=labels)

def hook_fn(grad):
    print("logits grad")
    print(grad)
logits.register_hook(hook_fn)

loss = F.cross_entropy(input=logits, target=labels)

Do the gradients for logits look good?
If so, can you add the same check at line 145 of your code, changing return self.fc(self.net(x), gt) to

resnet_output = self.net(x)

def hook_fn(grad):
    print("resnet output grad")
    print(grad)
resnet_output.register_hook(hook_fn)

return self.fc(resnet_output, gt)
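
If the printed tensors are too large to eyeball, a variant of the hook that only reports whether any NaNs are present might be easier to read (just a sketch; it assumes torch is imported in main.py):

def hook_fn(grad):
    # Report only whether the gradient contains NaN instead of printing the full tensor
    if torch.isnan(grad).any():
        print("resnet output grad contains NaN")
resnet_output.register_hook(hook_fn)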