I’m using PyTorch for physics simulations, and I’m running into the problem that the output of the simulation is NaN. I checked all my inputs (the parameters of the simulation, i.e. the leaf tensors with requires_grad=True) and none of them contain NaNs.
So I’d like to traverse the graph backward, inspecting the value at each node, in order to find out at which stage the NaN was introduced. I know I can use `x.grad_fn.next_functions` to step back through the graph, but I don’t see a way of getting the value of the forward pass at that stage. For example, in my case:
```
(Pdb) p cost
tensor(nan, grad_fn=<AddBackward0>)
(Pdb) p cost.grad_fn.next_functions
((<MseLossBackward object at 0x7f4a542caad0>, 0), (<AddBackward0 object at 0x7f4a542cab50>, 0))
(Pdb) p dir(cost.grad_fn.next_functions)
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_register_hook_dict', 'metadata', 'name', 'next_functions', 'register_hook', 'requires_grad']
```
Given a node obtained from `cost.grad_fn.next_functions`, how can I get the value of the forward pass at that node?
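To illustrate, here is a sketch of the walk I can already do on a toy graph: recursing through `next_functions` reaches every backward node, and `AccumulateGrad` nodes even expose their leaf tensor via `.variable`, but for the intermediate nodes I only get the backward function object, not the forward value.

```python
import torch

def walk(fn, depth=0):
    """Recursively print the backward graph starting at a grad_fn node."""
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    # AccumulateGrad nodes (the graph's leaves) hold the leaf tensor itself.
    if hasattr(fn, "variable"):
        print("  " * depth + f"  leaf value: {fn.variable}")
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

a = torch.tensor([2.0], requires_grad=True)
cost = (a * 3).sum()
walk(cost.grad_fn)
```

For the leaves this prints the actual tensor, but the intermediate nodes (here `SumBackward0` and `MulBackward0`) show only their type, which is exactly my problem.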
- I’m aware of this thread, but since my simulation involves hundreds of stages, it’s infeasible to add a NaN check at every point in the source code. Ideally, I wouldn’t need to modify the source code at all, just inspect the resulting objects.
- It looks like `torch.autograd.detect_anomaly` can be used to detect NaNs during the backward pass; however, in my case the NaN is created in the forward pass.
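A minimal repro of the second point (assuming I understand `detect_anomaly` correctly): the NaN is already present after the forward pass, but anomaly detection only raises once `backward()` is called.

```python
import torch

x = torch.tensor([-1.0], requires_grad=True)
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)              # forward pass silently produces nan
    print(torch.isnan(y).item())   # True -- no error raised at this point
    try:
        y.backward()               # anomaly detection only fires here
    except RuntimeError:
        print("raised during backward")
```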
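One workaround I considered (a sketch, assuming the simulation stages were `nn.Module` instances, which mine mostly are not): registering a forward hook on every submodule would report the first stage that emits a NaN without editing the simulation source itself.

```python
import torch
import torch.nn as nn

def nan_hook(module, inputs, output):
    """Raise as soon as any output tensor of a module contains a NaN."""
    outs = output if isinstance(output, tuple) else (output,)
    for i, out in enumerate(outs):
        if torch.is_tensor(out) and torch.isnan(out).any():
            raise RuntimeError(
                f"NaN in output {i} of {module.__class__.__name__}")

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
for m in model.modules():
    m.register_forward_hook(nan_hook)

out = model(torch.randn(2, 4))  # runs cleanly on finite inputs
```

This doesn’t help with the raw tensor operations in my simulation, though, which is why I’d still like to recover the forward values from the graph itself.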