Manually inspect values in the computation graph

I’m using PyTorch for physics simulations and I’m facing the problem that the output of the simulation is NaN. I checked all my inputs (the simulation parameters, which have requires_grad=True) and none of them contain NaNs.

So I’d like to traverse the graph backward, inspecting the value of each node, in order to find out at which stage the NaN was introduced. I know I can use x.grad_fn.next_functions to step back through the graph, but I don’t see a way of getting the value of the forward pass at each stage. For example, in my case:

(Pdb) p cost
tensor(nan, grad_fn=<AddBackward0>)
(Pdb) p cost.grad_fn.next_functions        
((<MseLossBackward object at 0x7f4a542caad0>, 0), (<AddBackward0 object at 0x7f4a542cab50>, 0))
(Pdb) p dir(cost.grad_fn.next_functions[0][0])
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad']

From the cost.grad_fn.next_functions[0][0] node how can I get the value of the forward pass?

  • I’m aware of this thread, but since my simulation involves hundreds of stages it’s infeasible to add a NaN check everywhere in the source code. Ideally I wouldn’t need to modify the source code at all, just inspect the resulting objects.
  • It looks like torch.autograd.detect_anomaly can be used to detect NaNs during the backward pass; in my case, however, the NaN is created in the forward pass.
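That said, recent PyTorch releases (roughly 1.9 and later) expose the forward tensors a backward node saved as `_saved_*` attributes on the `grad_fn`, so for those nodes the forward values can be inspected after all. A sketch of a traversal over `next_functions` that reports NaNs in saved tensors (the function name is my own):

```python
import torch

def find_nan_saved_tensors(fn, seen=None):
    """Walk the autograd graph from a grad_fn and check any forward
    tensors the backward nodes kept around (exposed as ``_saved_*``
    attributes in PyTorch >= 1.9) for NaNs."""
    if seen is None:
        seen = set()
    if fn is None or fn in seen:
        return []
    seen.add(fn)
    hits = []
    for attr in dir(fn):
        if attr.startswith("_saved_"):
            try:
                val = getattr(fn, attr)
            except RuntimeError:
                continue  # saved buffer already freed by a backward() call
            if torch.is_tensor(val) and torch.isnan(val).any():
                hits.append((type(fn).__name__, attr))
    for child, _ in fn.next_functions:
        hits += find_nan_saved_tensors(child, seen)
    return hits
```

In the session above, `find_nan_saved_tensors(cost.grad_fn)` would list the backward nodes whose saved forward tensors contain NaNs. The caveat: only tensors an operation actually needs for its backward are saved, so not every intermediate value is visible this way.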


I am confused: if the nan happens in the forward, why do you want to inspect the backward graph?
By doing a binary search on your code, it should be fairly fast to find out where the nan appears, no?

I don’t want to inspect the backward graph; I want to inspect the graph backwards, i.e. start from the final value (which is NaN) and traverse until I find the first non-NaN value, in order to identify the operation that introduced the NaN. Since the backward nodes need access to the values from the forward pass to compute the local gradients, I thought there might be a way to access the forward buffers from the backward nodes.

I’m not sure I understand what you mean by doing binary search on the code. Could you clarify this point? Thanks.

I’m afraid the graph we create to compute gradients does not store all the intermediate results (only the ones required for gradient computation), and in general they are not easily accessible to the user.
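One thing that can help without modifying the simulation code itself: if you can re-run the forward pass, recent PyTorch versions (1.10+) provide `torch.autograd.graph.saved_tensors_hooks`, a context manager whose pack hook sees every tensor an operation saves for backward. A sketch that fails loudly the first time a NaN is saved (the hook names are mine):

```python
import torch

def check_pack(t):
    # Called during the forward pass each time an op saves a tensor
    # for its backward computation.  Raising here gives a stack trace
    # that points at the forward operation that produced the NaN.
    if t.is_floating_point() and torch.isnan(t).any():
        raise RuntimeError("NaN tensor saved for backward")
    return t

def check_unpack(t):
    return t  # nothing to undo

x = torch.tensor([1.0, float("nan")], requires_grad=True)
try:
    with torch.autograd.graph.saved_tensors_hooks(check_pack, check_unpack):
        cost = (x * x).sum()  # mul saves x for backward -> hook fires
except RuntimeError as e:
    print(e)  # NaN tensor saved for backward
```

The same caveat applies as for the graph itself: the hook only sees tensors that some operation saved for backward, not every intermediate value of the forward pass.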

What I mean by binary search is:

  • Take the middle of your code
  • Is the nan already there?
  • If so, do the same thing on the first half
  • else do the same thing on the other half

This will give you the line where the nan first appears (even with 1k+ lines, about 10 iterations are enough).
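The check at each probe point can be a one-line helper (a sketch; the name is mine):

```python
import torch

def assert_finite(t, label=""):
    # Drop this at the midpoint being probed; it raises as soon as a
    # NaN or inf shows up, naming the checkpoint that caught it.
    if not torch.isfinite(t).all():
        raise RuntimeError(f"non-finite value at checkpoint {label!r}")
    return t

# e.g. halfway through the simulation:
# state = assert_finite(state, "stage 500")
```

Because the helper returns its input, it can be wrapped around any existing expression without restructuring the surrounding code.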