While testing with torch.autograd.detect_anomaly(check_nan=True) I get the following error during the backward pass: Function 'CumprodBackward0' returned nan values in its 0th output.
My model uses a version of the Gauss hypergeometric function 2F1, but I can't figure out where the NaNs are coming from, since I clamp all the outputs. I know it's related to the cumprod, but shouldn't the clamp fix that issue?
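For context, the function sums the series 2F1(a, b; c; z) = sum_{n>=0} (a)_n (b)_n / (c)_n * z^n / n!, building each term from the previous one via the ratio (a+n-1)(b+n-1) z / ((c+n-1) n); that running product is what the cumprod accumulates.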
Here's the function in question. It's part of a larger model, but I think this is the relevant piece.
import torch
from torch import Tensor


def hyp2f1(
    a: Tensor,
    b: Tensor,
    c: Tensor,
    z: Tensor
) -> Tensor:
    # Initialize variables
    eps_min = torch.tensor(1e-6, device=z.device)
    eps_max = torch.tensor(1e30, device=z.device)
    max_iter = 151
    n = torch.arange(1, max_iter, device=z.device).unsqueeze(0).expand(z.size(0), -1)
    terms = torch.ones((z.size(0), max_iter), device=z.device)
    an = a.unsqueeze(-1) + n - 1
    bn = b.unsqueeze(-1) + n - 1
    cn = (c.unsqueeze(-1) + n - 1) * n + eps_min
    terms[:, 1:] = an * bn * z.unsqueeze(-1) / cn
    terms = terms.cumprod(dim=-1).clamp(min=-eps_max, max=eps_max)
    tot = terms.sum(dim=-1).clamp(min=-eps_max, max=eps_max)
    tot = torch.nan_to_num(tot, nan=0.0, neginf=0.0, posinf=eps_max)
    return tot
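For reference, a minimal way to trigger the same error outside the model is to pick inputs that make the series diverge (any |z| > 1 blows the cumprod up to inf; the values below are made up just for illustration):

a = torch.tensor([5.0], requires_grad=True)
b = torch.tensor([5.0], requires_grad=True)
c = torch.tensor([1.0], requires_grad=True)
z = torch.tensor([2.0], requires_grad=True)

with torch.autograd.detect_anomaly(check_nan=True):
    tot = hyp2f1(a, b, c, z)
    tot.sum().backward()
# RuntimeError: Function 'CumprodBackward0' returned nan values in its 0th output.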
It sounds like cumprod is trying to multiply an inf with 0 at some point?
clamp does not actually get rid of NaNs:
>>> a = torch.tensor([float("nan")]).clamp(min=-1, max=1)
>>> a
tensor([nan])
Does clamp not remove the infinities? I get that the cumprod will produce an inf; I just want to replace it with a large value in that case.
I've tried rewriting the function to ignore the infinities, and I still get the same issue.
def hyp2f1(a: Tensor, b: Tensor, c: Tensor, z: Tensor) -> Tensor:
    eps_min = torch.tensor(1e-30, device=z.device)
    eps_max = torch.tensor(1e30, device=z.device)
    max_iter = 151
    n = torch.arange(1, max_iter, device=z.device).unsqueeze(0).expand(z.size(0), -1)
    terms = torch.zeros((z.size(0), max_iter), device=z.device)
    an = a.unsqueeze(-1) + n - 1
    bn = b.unsqueeze(-1) + n - 1
    cn = (c.unsqueeze(-1) + n - 1) * n
    tot = torch.zeros_like(z)
    # c > 0
    c_mask = c > 0
    terms[c_mask, 1:] = (
        torch.log(an[c_mask])
        + torch.log(bn[c_mask])
        + torch.log(z[c_mask].unsqueeze(-1))
        - torch.log(cn[c_mask])
    )
    cumsum = terms[c_mask].cumsum(dim=-1).exp().cumsum(dim=-1)
    last_valid = torch.where(torch.isfinite(cumsum), cumsum, torch.zeros_like(cumsum)).argmax(dim=-1)
    tot[c_mask] = torch.gather(cumsum, 1, last_valid.unsqueeze(1)).squeeze(1)
    # c <= 0
    terms[~c_mask, 1:] = (
        an[~c_mask]
        * bn[~c_mask]
        * z[~c_mask].unsqueeze(-1)
        / cn[~c_mask]
    )
    terms[~c_mask, 0] = 1
    tot[~c_mask] = terms[~c_mask].cumprod(dim=-1).sum(dim=-1)
    return tot
If the issue is inf * 0 somewhere inside cumprod, you'd need to clamp before the cumprod.
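For example (made-up numbers that overflow float32): clamping after the cumprod leaves the inf inside the graph, so the backward still hits 0 * inf, while clamping the inputs before the cumprod keeps every intermediate finite:

import torch

ratios = torch.tensor([3.4e38, 2.0], requires_grad=True)

# Clamp after cumprod: the inf is already in the graph, backward multiplies 0 * inf.
late = ratios.cumprod(dim=0).clamp(max=1e30)
late[0].backward()
print(ratios.grad)   # tensor([nan, nan])

ratios.grad = None

# Clamp before cumprod: no inf is ever created, so the gradients stay finite.
early = ratios.clamp(max=1e30).cumprod(dim=0)
early[0].backward()
print(ratios.grad)   # tensor([0., 0.])  (finite; the clamped element just gets zero grad)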
Thanks for following up.
I can't find any place where inf * 0 happens. Once a value goes to inf, it stays there.
In the second version I wrote, I even ignore the locations where it goes to inf, but I still see the issue. I've also tried using torch.where to remove the infinities before I call exp, but I get the same issue.
I've looked through my forward pass and I just can't find any NaN or any inf that propagates forward. Every time a value becomes inf, I clamp it or replace it with a where.
If we're observing:
- the input to cumprod has no infs or NaNs
- the output of cumprod has infs
then the NaN during the backward could be occurring because the grad output contains some zeros:
>>> a = torch.tensor([3.4*10**38, 2], requires_grad=True)
>>> b = a.cumprod(dim=0)
>>> b
tensor([3.4000e+38, inf], grad_fn=<CumprodBackward0>)
>>> c = b[0]
>>> c.backward()
>>> a.grad
tensor([nan, nan])
The grad output is guaranteed to have zeros in a case like the one above, where you mask out (or select away) some of the elements.
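You can see those zeros directly by putting a hook on the cumprod output:

>>> a = torch.tensor([3.4*10**38, 2], requires_grad=True)
>>> b = a.cumprod(dim=0)
>>> h = b.register_hook(lambda g: print("grad into cumprod output:", g))
>>> b[0].backward()
grad into cumprod output: tensor([1., 0.])

The 0 at index 1 is what gets multiplied with the inf at b[1] inside CumprodBackward0, and 0 * inf is where the NaN comes from.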
I think I got it.
In my function I work in log space, and before I exponentiate I use torch.where to replace log values that would overflow the exp (and end up producing NaNs in the backward) with 0. Here's my code:
def hyp2f1(a: Tensor, b: Tensor, c: Tensor, z: Tensor) -> Tensor:
    eps_min = torch.tensor(1e-30, device=z.device)
    eps_max = torch.tensor(1e30, device=z.device)
    max_iter = 151
    n = torch.arange(1, max_iter, device=z.device).unsqueeze(0).expand(z.size(0), -1)
    terms = torch.zeros((z.size(0), max_iter), device=z.device)
    an = a.unsqueeze(-1) + n - 1
    bn = b.unsqueeze(-1) + n - 1
    cn = (c.unsqueeze(-1) + n - 1) * n
    tot = torch.zeros_like(z)
    # c > 0
    c_mask = c > 0
    terms[c_mask, 1:] = (
        torch.log(an[c_mask])
        + torch.log(bn[c_mask])
        + torch.log(z[c_mask].unsqueeze(-1))
        - torch.log(cn[c_mask])
    )
    cumsum = terms[c_mask].cumsum(dim=-1)
    cumsum = torch.where(cumsum > 87, 0, cumsum).exp().cumsum(dim=-1)
    last_valid = torch.where(torch.isfinite(cumsum), cumsum, torch.zeros_like(cumsum)).argmax(dim=-1)
    tot[c_mask] = torch.gather(cumsum, 1, last_valid.unsqueeze(1)).squeeze(1)
    # c <= 0
    terms[~c_mask, 1:] = (
        an[~c_mask]
        * bn[~c_mask]
        * z[~c_mask].unsqueeze(-1)
        / cn[~c_mask]
    )
    terms[~c_mask, 0] = 1
    tot[~c_mask] = terms[~c_mask].cumprod(dim=-1).sum(dim=-1)
    return tot
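For anyone who finds this later, this is roughly how I'm sanity-checking it now (arbitrary parameters with c > 0 and |z| < 1, so only the log-space branch is exercised):

a = torch.tensor([0.5, 1.0], requires_grad=True)
b = torch.tensor([0.5, 2.0], requires_grad=True)
c = torch.tensor([1.5, 3.0], requires_grad=True)
z = torch.tensor([0.25, 0.5], requires_grad=True)

with torch.autograd.detect_anomaly(check_nan=True):
    tot = hyp2f1(a, b, c, z)
    tot.sum().backward()

print(tot)     # finite values
print(z.grad)  # finite gradients, no NaNs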
Thank you for the help!