Detect inf/nan in forward pass

autograd.detect_anomaly detects inf/nan in the backward pass.
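For reference, a minimal sketch of what that already covers: the NaN is created in the forward pass, but only the NaN gradient in the backward pass gets reported.

```python
import torch

x = torch.tensor([-1.0], requires_grad=True)
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)    # forward already produces NaN, but this is not flagged
    y.sum().backward()   # backward is flagged: "Function 'SqrtBackward0' returned nan values ..."
```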

I want to have the same in the forward pass, with the possibility to whitelist a few special operations, modules, or code blocks, e.g. masking attention energies to -inf, etc.

How?

Is there an easy way to install a hook after every operation?

Otherwise, I was thinking about installing a post-forward hook on every module, which would already be helpful; a rough sketch of what I have in mind follows below.
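Something like this, assuming a hypothetical whitelist of module names (e.g. an attention block that legitimately masks energies to -inf):

```python
import torch
import torch.nn as nn

def install_checks(model, whitelist=()):
    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outs:
                if torch.is_tensor(out) and out.is_floating_point() and not torch.isfinite(out).all():
                    raise RuntimeError(
                        f"non-finite values in the output of '{name}' ({module.__class__.__name__})"
                    )
        return hook

    handles = []
    for name, module in model.named_modules():
        if name in whitelist:  # skip whitelisted modules, e.g. masked attention
            continue
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call handle.remove() on each to uninstall the checks

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
handles = install_checks(model, whitelist={"attn"})  # "attn" is a hypothetical module name
model(torch.randn(2, 8))
```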

I was also thinking about using sys.settrace to install a trace function that inspects the locals.
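Roughly like this (a slow, debug-only sketch that scans each frame's locals for non-finite tensors on every function return):

```python
import sys
import torch

def tracer(frame, event, arg):
    if event == "return":
        for name, val in frame.f_locals.items():
            if torch.is_tensor(val) and val.is_floating_point() and not torch.isfinite(val).all():
                print(f"non-finite tensor '{name}' in {frame.f_code.co_name} "
                      f"({frame.f_code.co_filename}:{frame.f_lineno})")
    return tracer

sys.settrace(tracer)
# ... run the forward pass under the trace ...
sys.settrace(None)
```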

Or what else would you suggest?

Naively, you could use torch.isnan/.isinf/.isfinite inside the forward and add your custom logic to it. Alternatively, forward hooks should also work, as these would allow you to return a modified output. However, the issue would be that the operation causing the invalid outputs could still return an invalid gradient.
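For example, a minimal sketch of the in-forward check (MyBlock is just a hypothetical module):

```python
import torch
import torch.nn as nn

class MyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        out = self.proj(x)
        # custom logic goes here; raising is the simplest option
        if not torch.isfinite(out).all():
            raise RuntimeError("non-finite activation in MyBlock.forward")
        return out
```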

Yes, sure, but I don’t want to modify the code of the forward pass every time I want to do such checks.

Similarly, I might want to collect other statistics in the future, e.g. finding very large activations.
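The same hook mechanism could collect such statistics instead of (or in addition to) the finiteness check; a rough sketch, assuming we only track the maximum absolute activation per module:

```python
import torch
import torch.nn as nn
from collections import defaultdict

stats = defaultdict(float)  # module name -> largest absolute activation seen so far

def make_stats_hook(name):
    def hook(module, inputs, output):
        if torch.is_tensor(output) and output.is_floating_point():
            stats[name] = max(stats[name], output.detach().abs().max().item())
    return hook

model = nn.Sequential(nn.Linear(8, 32), nn.Tanh(), nn.Linear(32, 1))
for name, module in model.named_modules():
    module.register_forward_hook(make_stats_hook(name))

model(torch.randn(4, 8))
print(dict(stats))  # per-module maximum absolute activation
```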