Code debugging: How to implement the KL-divergence between two Generalized Dirichlet distributions in Python?

I am working on this paper (link to paper). I want to implement the KL-divergence between two Generalized Dirichlet distributions given in Equation 11 of the paper; see the screenshot below:

[screenshot of Equation 11: the KL divergence between two Generalized Dirichlet distributions]
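For reference, this is the closed form I believe Equation 11 expresses, written in the standard Connor–Mosimann parameterization; this transcription is mine, so the paper's notation may differ slightly:

    \mathrm{KL}(q \,\|\, p)
      = \sum_{k=1}^{K} \Big[ \ln\Gamma(\alpha_{1k}+\beta_{1k}) - \ln\Gamma(\alpha_{1k}) - \ln\Gamma(\beta_{1k})
                           - \ln\Gamma(\alpha_{2k}+\beta_{2k}) + \ln\Gamma(\alpha_{2k}) + \ln\Gamma(\beta_{2k}) \Big]
      + \sum_{k=1}^{K} (\alpha_{1k}-\alpha_{2k}) \, \mathbb{E}_q[\ln x_k]
      + \sum_{k=1}^{K} (\gamma_{1k}-\gamma_{2k}) \, \mathbb{E}_q\Big[\ln\Big(1 - \sum_{j \le k} x_j\Big)\Big]

    \text{where } \gamma_{ik} = \beta_{ik} - \alpha_{i,k+1} - \beta_{i,k+1} \ (k < K), \qquad \gamma_{iK} = \beta_{iK} - 1,

    \mathbb{E}_q[\ln x_k] = \psi(\alpha_{1k}) - \psi(\alpha_{1k}+\beta_{1k}) + \sum_{j<k} \big( \psi(\beta_{1j}) - \psi(\alpha_{1j}+\beta_{1j}) \big),

    \mathbb{E}_q\Big[\ln\Big(1 - \sum_{j \le k} x_j\Big)\Big] = \sum_{j \le k} \big( \psi(\beta_{1j}) - \psi(\alpha_{1j}+\beta_{1j}) \big).

The expectations follow from the stick-breaking view of the Generalized Dirichlet, where each stick weight is an independent Beta variable.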

Alpha_1 and Beta_1 are estimated by a decoder network, for example:

    import torch

    decoder_alpha = torch.tensor([[0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4],
                                  [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4]])
    decoder_beta = torch.tensor([[0.3, 0.6, 0.4, 0.8], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4],
                                 [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4]])

Also, Alpha_2 and Beta_2 are the parameters of the prior:

    prior_alpha = torch.tensor([[0.1, 0.1, 0.4, 0.1], [0.8, 0.7, 0.1, 0.4], [0.2, 0.8, 0.9, 0.1],
                                [0.1, 0.5, 0.2, 0.4], [0.1, 0.2, 0.1, 0.4], [0.2, 0.1, 0.3, 0.3]])
    prior_beta = torch.tensor([[0.7, 0.6, 0.1, 0.2], [0.5, 0.8, 0.1, 0.2], [0.2, 0.8, 0.5, 0.4],
                               [0.2, 0.6, 0.1, 0.4], [0.6, 0.8, 0.3, 0.2], [0.2, 0.6, 0.3, 0.9]])

Here is my implementation:

    batchSize = decoder_alpha.shape[0]

    decoderParamSum = decoder_alpha + decoder_beta
    priorParamSum = prior_alpha + prior_beta
    alphaParamsDiff = decoder_alpha - prior_alpha

    # First term: log-ratio of the normalizing constants (Beta functions).
    numerator = torch.lgamma(decoderParamSum) + torch.lgamma(prior_alpha) + torch.lgamma(prior_beta)
    denominator = torch.lgamma(decoder_alpha) + torch.lgamma(decoder_beta) + torch.lgamma(priorParamSum)
    firstTerm = torch.sum(numerator - denominator, dim=1)

    # Second term: (alpha_1 - alpha_2) weighted by digamma differences.
    secondTerm = torch.sum(torch.digamma(decoderParamSum) - torch.digamma(decoder_beta), dim=1)
    secondTerm = torch.reshape(secondTerm, (batchSize, 1))
    secondTerm = torch.digamma(decoder_alpha) - torch.digamma(decoder_beta) - secondTerm
    secondTerm = torch.sum(alphaParamsDiff * secondTerm, dim=1)

    # Third term: gamma_k = beta_k - alpha_{k+1} - beta_{k+1} for k < K,
    # and gamma_K = beta_K - 1 for the last dimension.
    thirdTerm = torch.sum(torch.digamma(decoderParamSum) - torch.digamma(decoder_beta), dim=1)
    thirdTerm = torch.reshape(thirdTerm, (batchSize, 1))
    v1 = torch.cat([decoder_beta[:, :-1] - decoder_alpha[:, 1:] - decoder_beta[:, 1:],
                    decoder_beta[:, -1:] - 1], dim=-1)
    v2 = torch.cat([prior_beta[:, :-1] - prior_alpha[:, 1:] - prior_beta[:, 1:],
                    prior_beta[:, -1:] - 1], dim=-1)
    thirdTerm = torch.sum((v1 - v2) * thirdTerm, dim=1)

    KLD = firstTerm - secondTerm + thirdTerm

However, the loss values I'm getting are far from what I expect: they are large negative values and unstable across iterations. Since a KL divergence should always be non-negative, I'm guessing there is a problem with my implementation.
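One thing I am unsure about: my secondTerm and thirdTerm reduce the digamma differences with torch.sum over the whole dimension, but if the inner sums in Equation 11 only run up to index k (as in the stick-breaking expectations above), they should be cumulative sums instead. Here is a minimal self-contained sketch of that cumulative variant; the function name and the cumulative reading of Equation 11 are my assumptions, not something confirmed by the paper:

    import torch

    def gd_kl_cumulative(a1, b1, a2, b2):
        """Hypothetical sketch: KL(GD(a1, b1) || GD(a2, b2)) per batch row,
        assuming the Connor-Mosimann parameterization and cumulative inner sums."""
        # Log-ratio of the normalizing constants.
        term1 = (torch.lgamma(a1 + b1) - torch.lgamma(a1) - torch.lgamma(b1)
                 - torch.lgamma(a2 + b2) + torch.lgamma(a2) + torch.lgamma(b2)).sum(dim=1)
        # Stick-breaking expectations with stick weights v_j ~ Beta(a1_j, b1_j):
        log_one_minus_v = torch.digamma(b1) - torch.digamma(a1 + b1)  # E[ln(1 - v_j)]
        cum = torch.cumsum(log_one_minus_v, dim=1)   # E[ln(1 - sum_{j<=k} x_j)]
        excl = cum - log_one_minus_v                 # same sum, excluding index k
        e_log_x = torch.digamma(a1) - torch.digamma(a1 + b1) + excl  # E[ln x_k]
        # gamma_k = beta_k - alpha_{k+1} - beta_{k+1} (k < K), gamma_K = beta_K - 1.
        g1 = torch.cat([b1[:, :-1] - a1[:, 1:] - b1[:, 1:], b1[:, -1:] - 1], dim=-1)
        g2 = torch.cat([b2[:, :-1] - a2[:, 1:] - b2[:, 1:], b2[:, -1:] - 1], dim=-1)
        term2 = ((a1 - a2) * e_log_x).sum(dim=1)
        term3 = ((g1 - g2) * cum).sum(dim=1)
        return term1 + term2 + term3

With identical decoder and prior parameters this variant returns zero, as a KL should, but I don't know whether it matches the paper's intent.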

Can anyone please check my implementation of the KL-divergence, or point me to an existing Python implementation? I have searched the internet, but nothing has come up. Thank you in advance.
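Update: to have a ground truth to compare against, I put together a Monte Carlo estimate of the KL using the stick-breaking construction of the Generalized Dirichlet (sample v_k ~ Beta(alpha_k, beta_k), then x_k = v_k * prod_{j<k}(1 - v_j)). The helper name and the log-density below are my own transcription of the Connor–Mosimann form, so treat this as a sketch:

    import torch

    def gd_log_prob(x, a, b):
        # Log-density of the Generalized Dirichlet (Connor-Mosimann form) at x,
        # where x sums to less than 1 along the last dimension.
        g = torch.cat([b[:, :-1] - a[:, 1:] - b[:, 1:], b[:, -1:] - 1], dim=-1)
        log_norm = torch.lgamma(a + b) - torch.lgamma(a) - torch.lgamma(b)
        # 1 - sum_{j<=k} x_j; clamped for numerical safety near zero.
        tail = (1.0 - torch.cumsum(x, dim=-1)).clamp_min(1e-12)
        return (log_norm + (a - 1) * torch.log(x) + g * torch.log(tail)).sum(dim=-1)

    torch.manual_seed(0)

    # Check the first batch row only.
    a1, b1 = decoder_alpha[:1], decoder_beta[:1]
    a2, b2 = prior_alpha[:1], prior_beta[:1]

    # Stick-breaking samples from GD(a1, b1): x_k = v_k * prod_{j<k} (1 - v_j).
    v = torch.distributions.Beta(a1, b1).sample((100000,))  # shape (S, 1, K)
    v = v.clamp(1e-6, 1 - 1e-6)                             # avoid log(0) later
    remainder = torch.cumprod(1 - v, dim=-1)
    x = v * torch.cat([torch.ones_like(v[..., :1]), remainder[..., :-1]], dim=-1)

    # KL(q || p) ~= E_q[log q(x) - log p(x)].
    mc_kl = (gd_log_prob(x, a1, b1) - gd_log_prob(x, a2, b2)).mean()
    print(mc_kl)  # compare against the closed-form value for the same row

My plan is to compare this estimate with the closed-form output for the same row to see which version of the analytic formula is right.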