I am working from this paper (link to paper). I want to implement the KL divergence between two generalized Dirichlet distributions given in equation 11 of the paper; see the screenshot below:
Alpha_1 and Beta_1 are estimated from a decoder network, for example:
import torch
decoder_alpha = torch.tensor([[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4]])
decoder_beta = torch.tensor([[0.3, 0.6, 0.4, 0.8],[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4]])
Also, Alpha_2 and Beta_2 are priors:
prior_alpha = torch.tensor([[0.1, 0.1, 0.4, 0.1],[0.8, 0.7, 0.1, 0.4],[0.2, 0.8, 0.9, 0.1], [0.1, 0.5, 0.2, 0.4], [0.1, 0.2, 0.1, 0.4], [0.2, 0.1, 0.3, 0.3]])
prior_beta = torch.tensor([[0.7, 0.6, 0.1, 0.2],[0.5, 0.8, 0.1, 0.2],[0.2, 0.8, 0.5, 0.4], [0.2, 0.6, 0.1, 0.4], [0.6, 0.8, 0.3, 0.2], [0.2, 0.6, 0.3, 0.9]])
Here is my implementation:
decoderParamSum = decoder_alpha + decoder_beta
priorParamSum = prior_alpha + prior_beta
alphaParamsDiff = decoder_alpha - prior_alpha
numerator = torch.lgamma(decoderParamSum) + torch.lgamma(prior_alpha) + torch.lgamma(prior_beta)
denominator = torch.lgamma(decoder_alpha) + torch.lgamma(decoder_beta) + torch.lgamma(priorParamSum)
firstTerm = torch.sum((numerator - denominator), dim=1)
secondTerm = torch.sum((torch.digamma(decoderParamSum)-torch.digamma(decoder_beta)), dim=1)
secondTerm = torch.reshape(secondTerm, (decoder_alpha.shape[0], 1))
secondTerm = torch.digamma(decoder_alpha) - torch.digamma(decoder_beta) - secondTerm
secondTerm = torch.sum(torch.multiply(alphaParamsDiff, secondTerm), dim=1)
thirdTerm = torch.sum((torch.digamma(decoderParamSum)-torch.digamma(decoder_beta)), dim=1)
thirdTerm = torch.reshape(thirdTerm, (decoder_alpha.shape[0], 1))
v1 = torch.cat([decoder_beta[:,:-1] -decoder_alpha[:, 1:] - decoder_beta[:, 1:], decoder_beta[:, -1:] - 1], dim=-1)
v2 = torch.cat([prior_beta[:,:-1] -prior_alpha[:, 1:] - prior_beta[:, 1:], prior_beta[:, -1:] - 1], dim=-1)
thirdTerm = torch.sum((torch.multiply((v1-v2), thirdTerm)), dim=1)
KLD = firstTerm - secondTerm + thirdTerm
However, the loss values I'm getting are far from what I expect: they are large, negative, and unstable. Since a KL divergence cannot be negative, I suspect there is a problem with my implementation.
Could anyone please check my implementation of the KL divergence, or point me to an existing Python implementation? I have searched online but found nothing. Thank you in advance.
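For reference, here is a minimal sanity check I would compare against. It assumes the paper uses the standard Connor-Mosimann parameterization, which I cannot confirm from equation 11 alone. Under that assumption a generalized Dirichlet variable is an invertible stick-breaking transform of independent Beta(alpha_i, beta_i) variables, so the KL divergence should equal the sum of the per-dimension Beta KLs. The function names `gd_kl_via_betas` and `gd_kl_closed_form` are hypothetical, and the closed form is my own derivation rather than a copy of equation 11. It uses a cumulative sum (`torch.cumsum`) for the nested sum and a sign arrangement that differs from my code above.

```python
import torch
from torch.distributions import Beta, kl_divergence


def gd_kl_via_betas(alpha1, beta1, alpha2, beta2):
    """KL(GD1 || GD2) as a sum of independent Beta KLs.

    Assumes the Connor-Mosimann form, where each dimension is an
    independent Beta(alpha_i, beta_i) stick-breaking fraction.
    """
    per_dim = kl_divergence(Beta(alpha1, beta1), Beta(alpha2, beta2))
    return per_dim.sum(dim=-1)


def gd_kl_closed_form(alpha1, beta1, alpha2, beta2):
    """Same quantity written with explicit lgamma/digamma terms."""
    # Difference of the log normalising constants.
    log_norm = (
        torch.lgamma(alpha1 + beta1) - torch.lgamma(alpha1) - torch.lgamma(beta1)
        - torch.lgamma(alpha2 + beta2) + torch.lgamma(alpha2) + torch.lgamma(beta2)
    ).sum(dim=-1)

    # c[..., i] = sum_{l <= i} (digamma(alpha1_l + beta1_l) - digamma(beta1_l));
    # a running (cumulative) sum, not a sum over the whole row.
    c = torch.cumsum(torch.digamma(alpha1 + beta1) - torch.digamma(beta1), dim=-1)

    # E[log x_i] under the first distribution.
    e_log_x = torch.digamma(alpha1) - torch.digamma(beta1) - c

    def gamma(alpha, beta):
        # gamma_i = beta_i - alpha_{i+1} - beta_{i+1}; the last entry is beta_k - 1.
        head = beta[..., :-1] - alpha[..., 1:] - beta[..., 1:]
        return torch.cat([head, beta[..., -1:] - 1], dim=-1)

    g_diff = gamma(alpha1, beta1) - gamma(alpha2, beta2)
    # E[log(1 - x_1 - ... - x_i)] = -c[..., i], hence the minus sign.
    return log_norm + ((alpha1 - alpha2) * e_log_x).sum(dim=-1) - (g_diff * c).sum(dim=-1)


if __name__ == "__main__":
    # Small illustrative inputs (first rows of the tensors in the post).
    a1 = torch.tensor([[0.2, 0.3, 0.3, 0.4]])
    b1 = torch.tensor([[0.3, 0.6, 0.4, 0.8]])
    a2 = torch.tensor([[0.1, 0.1, 0.4, 0.1]])
    b2 = torch.tensor([[0.7, 0.6, 0.1, 0.2]])
    print(gd_kl_via_betas(a1, b1, a2, b2))
    print(gd_kl_closed_form(a1, b1, a2, b2))
```

If the assumption holds, both functions should agree up to floating-point error and never return a negative value. A result from my code that disagrees with them would point at the nested sum or the signs.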