Equivalent to numpy's nan_to_num?

NumPy has a nan_to_num function, which replaces nan, inf, and -inf with three constants (by default zero, a very large positive number, and a very large negative number). Is there an equivalent in PyTorch? If not, how would I construct my own without in-place assignments that would break the gradient? Would I have to assemble something element-by-element and use torch.cat, or is there some other solution? I am fine with all replaced elements having a gradient of 0.


I’m currently using the following code, but it’s woefully inefficient:

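import torch

def nan_to_num(t, mynan=0.):
    # Fast path: every element is finite, so there is nothing to replace.
    if torch.all(torch.isfinite(t)):
        return t
    # Base case: a 0-dim tensor that is nan or +/-inf becomes the constant.
    if t.dim() == 0:
        return torch.tensor(mynan, dtype=t.dtype, device=t.device)
    # Recurse over the first dimension, passing mynan along, and reassemble.
    return torch.cat([nan_to_num(l, mynan).unsqueeze(0) for l in t], 0)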

Did you find an answer for this?
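
One way to avoid both the recursion and in-place assignment is torch.where, which selects elementwise between two tensors and lets the gradient flow only through the elements it keeps. The following is a minimal sketch, not an official API: the name nan_to_num_where is hypothetical, it assumes a floating-point input, and the default replacements for inf and -inf are assumed to be the largest and smallest finite values of the dtype, mirroring NumPy's behaviour.

```python
import torch

def nan_to_num_where(t, nan=0.0, posinf=None, neginf=None):
    # Hypothetical helper (not part of PyTorch); assumes t has a floating dtype.
    info = torch.finfo(t.dtype)
    posinf = info.max if posinf is None else posinf
    neginf = info.min if neginf is None else neginf
    # Each torch.where builds a new tensor, so nothing is modified in place.
    out = torch.where(torch.isnan(t), torch.full_like(t, nan), t)
    out = torch.where(out == float("inf"), torch.full_like(t, posinf), out)
    out = torch.where(out == float("-inf"), torch.full_like(t, neginf), out)
    return out

x = torch.tensor([1.0, float("nan"), float("inf"), float("-inf")],
                 requires_grad=True)
y = nan_to_num_where(x, nan=0.0, posinf=100.0, neginf=-100.0)
y.sum().backward()
print(y)       # replaced values
print(x.grad)  # gradient is zero wherever a value was replaced
```

Note that this only zeroes the gradient at the replacement step itself. If the nan or inf was produced by an earlier operation in the graph, that operation's own backward pass may still yield non-finite gradients.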

We’ve added nan_to_num to PyTorch. It is available in the nightly builds now and will be included in PyTorch 1.8. See its documentation here:
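
For illustration, a short sketch of what a call looks like, assuming PyTorch 1.8 or later is installed. The replacement values 1e6 and -1e6 are arbitrary choices for the example. If posinf or neginf are left unset, the largest and smallest finite values of the input's dtype are used.

```python
import torch

x = torch.tensor([1.0, float("nan"), float("inf"), float("-inf")],
                 requires_grad=True)

# nan, posinf and neginf are keyword arguments of torch.nan_to_num.
y = torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)

# The op is differentiable; replaced entries are expected to get zero gradient.
y.sum().backward()
print(y)
print(x.grad)
```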