Error during backward pass: Function 'CumprodBackward0' returned nan values in its 0th output

While testing with torch.autograd.detect_anomaly(check_nan=True), I get: Error during backward pass: Function 'CumprodBackward0' returned nan values in its 0th output. My model uses a version of the Gauss hypergeometric function 2F1, but I can’t figure out where the NaNs come from, since I clamp all the outputs. I know it’s related to the cumprod, but shouldn’t the clamp fix that issue?

Here’s the function in question. It’s part of a larger model, but I think this is the relevant piece.

import torch
from torch import Tensor

def hyp2f1(
    a: Tensor,
    b: Tensor,
    c: Tensor,
    z: Tensor
) -> Tensor:
    # Initialize variables
    eps_min = torch.tensor(1e-6, device=z.device)
    eps_max = torch.tensor(1e30, device=z.device)
    max_iter = 151

    n = torch.arange(1, max_iter, device=z.device).unsqueeze(0).expand(z.size(0), -1)
    terms = torch.ones((z.size(0), max_iter), device=z.device)
    
    an = a.unsqueeze(-1) + n - 1
    bn = b.unsqueeze(-1) + n - 1
    cn = (c.unsqueeze(-1) + n - 1) * n + eps_min

    terms[:, 1:] = an * bn * z.unsqueeze(-1) / cn
    terms = terms.cumprod(dim=-1).clamp(min=-eps_max, max=eps_max)
    
    tot = terms.sum(dim=-1).clamp(min=-eps_max, max=eps_max)
    tot = torch.nan_to_num(tot, nan=0.0, neginf=0.0, posinf=eps_max)
    return tot
  

It sounds like cumprod is trying to multiply an inf by 0 at some point?

Also, clamp does not actually get rid of NaNs:

>>> a = torch.tensor([float("nan")]).clamp(min=-1, max=1)
>>> a
tensor([nan])
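
To make the difference concrete, here is a minimal sketch (the tensor values are made up for illustration) showing that clamp bounds infinities but passes NaN through, while torch.nan_to_num replaces NaN as well:

```python
import torch

x = torch.tensor([float("nan"), float("inf"), float("-inf"), 0.5])

# clamp bounds +/-inf to the given limits but leaves NaN untouched.
clamped = x.clamp(min=-1.0, max=1.0)

# nan_to_num replaces NaN (and the infinities) with the finite values given.
cleaned = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)

print(clamped)
print(cleaned)
```

Both calls only change the forward values. Neither changes what an earlier op such as cumprod uses in its own backward.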

Does clamp not remove the infinities? I get that the cumprod will produce an inf; I just want to replace it with a large value.

I’ve tried rewriting to ignore the infinities and I still get the same issue.

def hyp2f1(a: Tensor, b: Tensor, c: Tensor, z: Tensor) -> Tensor:
    eps_min = torch.tensor(1e-30, device=z.device)
    eps_max = torch.tensor(1e30, device=z.device)
    max_iter = 151

    n = torch.arange(1, max_iter, device=z.device).unsqueeze(0).expand(z.size(0), -1)
    terms = torch.zeros((z.size(0), max_iter), device=z.device)
    
    an = a.unsqueeze(-1) + n - 1
    bn = b.unsqueeze(-1) + n - 1
    cn = (c.unsqueeze(-1) + n - 1) * n
    
    tot = torch.zeros_like(z)
    
    # c > 0
    c_mask = c > 0
    terms[c_mask, 1:] = (
        torch.log(an[c_mask])
        + torch.log(bn[c_mask])
        + torch.log(z[c_mask].unsqueeze(-1))
        - torch.log(cn[c_mask])
    )
    cumsum = terms[c_mask].cumsum(dim=-1).exp().cumsum(dim=-1)
    last_valid = torch.where(torch.isfinite(cumsum), cumsum, torch.zeros_like(cumsum)).argmax(dim=-1)
    tot[c_mask] = torch.gather(cumsum, 1, last_valid.unsqueeze(1)).squeeze(1)
    
    # c <= 0
    terms[~c_mask, 1:] = (
        an[~c_mask]
        * bn[~c_mask]
        * z[~c_mask].unsqueeze(-1)
        / cn[~c_mask]
    )
    terms[~c_mask, 0] = 1
    tot[~c_mask] = terms[~c_mask].cumprod(dim=-1).sum(dim=-1)
    return tot

If the issue is inf * 0 somewhere inside cumprod, you’d need to clamp before cumprod.
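
A small sketch of why the order matters, using made-up float32 values rather than the model's actual tensors. Clamping after cumprod fixes the forward value, but cumprod's backward still works with the inf it produced. Clamping the input keeps the running product finite in the first place. This only helps if the bound is small enough that the product cannot overflow over the whole sequence, which is an assumption that holds in this two-element toy case.

```python
import torch

# Clamp AFTER cumprod: the forward value is bounded, the backward is not.
a = torch.tensor([3.4e38, 2.0], requires_grad=True)
after = a.cumprod(dim=0).clamp(max=1e30)   # cumprod itself still produced inf
after.sum().backward()
grad_clamp_after = a.grad.clone()

# Clamp BEFORE cumprod: the running product never overflows here.
b = torch.tensor([3.4e38, 2.0], requires_grad=True)
before = b.clamp(max=1e18).cumprod(dim=0)  # roughly [1e18, 2e18], both finite
before.sum().backward()
grad_clamp_before = b.grad.clone()

print(grad_clamp_after)   # contains nan
print(grad_clamp_before)  # all finite
```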

Thanks for following up.

I can’t find any place where inf * 0 happens. Once a value goes to inf, it stays there.

In the 2nd version I wrote, I even ignore the locations where it goes to inf, but I still see the issue. I’ve even tried using torch.where to remove the infinities before I use exp, but I get the same issue.

I’ve looked through my forward pass and I just can’t find any NaN or any inf that propagates forward. Every time a value goes to inf, I clamp it or replace it with torch.where.

If we’re observing:

  • the input to cumprod has no infs, nans
  • the output of cumprod has infs

then the NaN during backward could be occurring because the grad output of cumprod contains some zeros, and 0 * inf is nan:

>>> a = torch.tensor([3.4*10**38, 2], requires_grad=True)
>>> b = a.cumprod(dim=0)
>>> b
tensor([3.4000e+38,        inf], grad_fn=<CumprodBackward0>)
>>> c = b[0]
>>> c.backward()
>>> a.grad
tensor([nan, nan])

The gradient is guaranteed to contain zeros whenever you mask out an element, as in the example above.
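
The mechanism can be sketched by hand. For nonzero inputs, the derivative of the j-th cumprod output with respect to the i-th input (i <= j) is out_j / x_i, so the input gradient is a reversed cumulative sum of grad_out * out, divided by x. The sketch below evaluates that formula manually for the same two-element example and compares it with autograd. It illustrates the closed-form derivative and is not meant as a copy of PyTorch's internal implementation.

```python
import torch

x = torch.tensor([3.4e38, 2.0], requires_grad=True)
out = x.cumprod(dim=0)                  # out[1] overflows float32 to inf

# Upstream gradient when only out[0] is used: the unused entry gets 0.
grad_out = torch.tensor([1.0, 0.0])

with torch.no_grad():
    w = grad_out * out                  # second entry is 0 * inf = nan
    # Reversed cumulative sum spreads the nan to every earlier position.
    manual_grad = w.flip(0).cumsum(0).flip(0) / x

out.backward(grad_out)                  # autograd result for comparison
print(manual_grad, x.grad)
```

Because of the reversed cumulative sum, a single inf output with a zero upstream gradient is enough to turn the gradient of every earlier input into NaN.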

I think I got it.

In my function, I work in log space, and before I exponentiate, I use torch.where to replace the values that would cause NaNs with 0. Here’s my code:

def hyp2f1(a: Tensor, b: Tensor, c: Tensor, z: Tensor) -> Tensor:
    eps_min = torch.tensor(1e-30, device=z.device)
    eps_max = torch.tensor(1e30, device=z.device)
    max_iter = 151

    n = torch.arange(1, max_iter, device=z.device).unsqueeze(0).expand(z.size(0), -1)
    terms = torch.zeros((z.size(0), max_iter), device=z.device)
    
    an = a.unsqueeze(-1) + n - 1
    bn = b.unsqueeze(-1) + n - 1
    cn = (c.unsqueeze(-1) + n - 1) * n
    
    tot = torch.zeros_like(z)
    
    # c > 0
    c_mask = c > 0
    terms[c_mask, 1:] = (
        torch.log(an[c_mask])
        + torch.log(bn[c_mask])
        + torch.log(z[c_mask].unsqueeze(-1))
        - torch.log(cn[c_mask])
    )
    
    cumsum = terms[c_mask].cumsum(dim=-1)
    cumsum = torch.where(cumsum > 87, 0, cumsum).exp().cumsum(dim=-1)
    last_valid = torch.where(torch.isfinite(cumsum), cumsum, torch.zeros_like(cumsum)).argmax(dim=-1)
    tot[c_mask] = torch.gather(cumsum, 1, last_valid.unsqueeze(1)).squeeze(1)
    
    # c <= 0
    terms[~c_mask, 1:] = (
        an[~c_mask]
        * bn[~c_mask]
        * z[~c_mask].unsqueeze(-1)
        / cn[~c_mask]
    )
    terms[~c_mask, 0] = 1
    tot[~c_mask] = terms[~c_mask].cumprod(dim=-1).sum(dim=-1)
    return tot
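
The reason the mask has to come before exp can be seen in isolation. In this sketch (toy values, with the same threshold of 87 as above), applying torch.where after exp still gives a NaN gradient, because exp's backward multiplies the zero incoming gradient by the inf it computed in the forward pass. Applying torch.where before exp means exp never sees the large value.

```python
import torch

x = torch.tensor([1.0, 100.0], requires_grad=True)   # exp(100) overflows float32

# Mask AFTER exp: the forward value looks fine, the backward is 0 * inf = nan.
masked_after = torch.where(x > 87, torch.zeros_like(x), x.exp())
masked_after.sum().backward()
grad_after = x.grad.clone()

x.grad = None  # reset before the second backward pass

# Mask BEFORE exp: exp only ever sees safe inputs.
masked_before = torch.where(x > 87, torch.zeros_like(x), x).exp()
masked_before.sum().backward()
grad_before = x.grad.clone()

print(grad_after)
print(grad_before)
```

One thing to keep in mind: a replaced entry becomes exp(0) = 1 in the forward value, so it still contributes 1 to any later sum unless it is masked again.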

Thank you for the help!