Getting NaN values for the KL divergence between Beta and Kumaraswamy distributions

Hi everyone,

I have a variational autoencoder architecture that uses a stick-breaking prior. One KL divergence component of my model is the KL term between the Kumaraswamy and Beta distributions. However, this term produces NaN values. The original function I used was the following:
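For reference, the closed-form expression I am trying to implement (taken, if I have it right, from Nalisnick & Smyth's stick-breaking VAE paper) is

KL(Kumaraswamy(a, b) || Beta(\alpha, \beta))
    = \frac{a - \alpha}{a} \left( -\gamma - \psi(b) - \frac{1}{b} \right)
    + \log(ab) + \log B(\alpha, \beta) - \frac{b - 1}{b}
    + (\beta - 1)\, b \sum_{m=1}^{\infty} \frac{1}{m + ab} B\!\left(\frac{m}{a}, b\right)

where \gamma is the Euler-Mascheroni constant, \psi is the digamma function, and B is the beta function; the infinite sum is the Taylor expansion of the E[log(1 - v)] term, which the loop below truncates.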

import torch

local_device = torch.device('cuda')
SMALL = torch.tensor(1e-10, dtype=torch.float64, device=local_device)
EULER_GAMMA = torch.tensor(0.5772156649015329, dtype=torch.float64, device=local_device)
upper_limit = 10000.0

def beta_fn(a, b):
    # beta function B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b), via lgamma for stability
    return torch.exp(
        torch.lgamma(torch.tensor(a + SMALL, dtype=torch.float64, device=local_device))
        + torch.lgamma(torch.tensor(b + SMALL, dtype=torch.float64, device=local_device))
        - torch.lgamma(torch.tensor(a + b + SMALL, dtype=torch.float64, device=local_device))
    )

def compute_kumar2beta_kld(a, b, alpha, beta):
    # KL( Kumaraswamy(a, b) || Beta(alpha, beta) )
    ab    = torch.mul(a, b) + SMALL
    a_inv = torch.pow(a + SMALL, -1)
    b_inv = torch.pow(b + SMALL, -1)
    # truncated Taylor expansion for the E[log(1 - v)] term (m = 1, ..., 11)
    kl = torch.mul(torch.pow(1 + ab, -1), beta_fn(a_inv, b))
    for idx in range(10):
        kl += torch.mul(torch.pow(idx + 2 + ab, -1), beta_fn(torch.mul(idx + 2., a_inv), b))
    kl = torch.mul(torch.mul(beta - 1, b), kl)
    # E[log v] term (I previously approximated psi(b) with a truncated series; digamma replaces it)
    psi_b = torch.digamma(b + SMALL)
    kl += torch.mul(torch.div(a - alpha, a + SMALL), -EULER_GAMMA - psi_b - b_inv)
    # add normalization constants
    kl += torch.log(ab) + torch.log(beta_fn(alpha, beta) + SMALL)
    # final term
    kl += torch.div(-(b - 1), b + SMALL)
    return kl
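For context, this is roughly how I evaluate it (the parameter values here are hypothetical, just for illustration):

a = torch.full((4,), 0.5, dtype=torch.float64, device=local_device)
b = torch.full((4,), 2.0, dtype=torch.float64, device=local_device)
kl = compute_kumar2beta_kld(a, b, 1.0, 5.0)  # against a Beta(1, 5) prior
print(torch.isnan(kl).any())                 # check whether any component went NaN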

This fails with the following error:
/opt/conda/conda-bld/pytorch_1631630815121/work/aten/src/ATen/native/cuda/Loss.cu:111: operator(): block: [259,0,0], thread: [30,0,0] Assertion input_val >= zero && input_val <= one failed.
epoch 0 --- iteration 0: , kumar2beta KL = nan
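To get a more informative stack trace for this device-side assert, I believe the standard approach is to enable blocking CUDA launches and autograd anomaly detection (these are generic PyTorch debugging switches, not specific to my model):

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # must be set before CUDA is initialized

import torch
torch.autograd.set_detect_anomaly(True)   # reports the op that produced NaN in backward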

I then modified the function as follows to try to resolve the issue:

def compute_kumar2beta_kld(a, b, alpha, beta):
    ab    = torch.mul(a, b)
    a_inv = torch.reciprocal(a)
    b_inv = torch.reciprocal(b)
    # accumulate the Taylor-expansion terms for E[log(1 - v)] with logsumexp
    terms = [beta_fn(torch.mul(m, a_inv), b) - torch.log(m + ab) for m in range(1, 10 + 1)]
    log_taylor = torch.logsumexp(torch.stack(terms, dim=-1), dim=-1)
    kl  = torch.mul(torch.mul(beta - 1, b), torch.exp(log_taylor))
    psi_b = torch.digamma(b + SMALL)
    kl += torch.mul(torch.div(a - alpha, torch.clamp(a, SMALL, upper_limit)), -EULER_GAMMA - psi_b - b_inv)
    kl += torch.log(torch.clamp(ab, SMALL, upper_limit)) + torch.log(torch.clamp(beta_fn(alpha, beta), SMALL, upper_limit))
    kl += torch.div(-(b - 1), torch.clamp(b, SMALL, upper_limit))
    return torch.clamp(kl, min=0.)
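To narrow down which parameter ranges trigger the NaN, I can run a small scan over arbitrary (hypothetical) values:

for val in [1e-6, 1e-3, 1.0, 1e3]:
    a = torch.full((1,), val, dtype=torch.float64, device=local_device)
    b = torch.full((1,), val, dtype=torch.float64, device=local_device)
    out = compute_kumar2beta_kld(a, b, 1.0, 5.0)
    print(val, out.item(), torch.isnan(out).item())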

But I still get the same error. I would appreciate it if someone could suggest how to fix this function so that it no longer produces NaN values, ideally without adding too much bias to this KL term.
Thanks