I’m using PyTorch for physics simulations and I’m facing the problem that the output of the simulation is NaN. I checked all my inputs (the parameters of the simulation, which have requires_grad=True) and none of them contain NaNs.
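For reference, this is roughly how I checked the inputs (params stands in for my actual collection of simulation parameters):

import torch

for p in params:
    # Inputs are clean: no parameter tensor contains a NaN.
    assert not torch.isnan(p).any(), f"NaN in input parameter of shape {tuple(p.shape)}"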
So I’d like to traverse the graph backward, inspecting the value of each node, in order to find out at which stage the NaN was introduced. I know I can use x.grad_fn.next_functions to step back through the graph, but I don’t see a way of getting the value of the forward pass at that stage. For example, in my case:
(Pdb) p cost
tensor(nan, grad_fn=<AddBackward0>)
(Pdb) p cost.grad_fn.next_functions
((<MseLossBackward object at 0x7f4a542caad0>, 0), (<AddBackward0 object at 0x7f4a542cab50>, 0))
(Pdb) p dir(cost.grad_fn.next_functions[0][0])
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad']
From the cost.grad_fn.next_functions[0][0] node, how can I get the value of the forward pass?
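To make it concrete, here is a minimal sketch of the traversal I have in mind (walk is just my own helper, not a PyTorch API). It visits every grad_fn node, but all I can print are the node objects themselves; only the leaf AccumulateGrad nodes expose a tensor via .variable:

def walk(fn, depth=0):
    # Recursively step backward through the autograd graph.
    if fn is None:
        return
    print('  ' * depth + fn.name())
    # Leaf nodes (AccumulateGrad) expose their tensor via .variable,
    # but I see no analogous attribute on intermediate nodes.
    if hasattr(fn, 'variable'):
        print('  ' * depth + 'leaf value: ' + repr(fn.variable))
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

walk(cost.grad_fn)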
- I’m aware of this thread, but since my simulation involves hundreds of stages it’s infeasible to add a NaN check everywhere in the source code (the kind of check I mean is sketched after this list). Ideally I wouldn’t need to modify the source code at all, just inspect the resulting objects.
- It looks like torch.autograd.detect_anomaly can be used to detect NaNs during the backward pass; however, in my case the NaN is created in the forward pass (see the usage sketch after this list).