Calculate pair-wise KL divergence

Hi everyone, is there a way to efficiently calculate the pair-wise KL divergence between two sets of samples A (batch_size x dimension) and B (batch_size x dimension), returning a tensor of shape (batch_size x batch_size), in PyTorch? It's similar to torch.cdist, but with KL divergence rather than a p-norm distance. Thank you so much.

Hi Hoang!

Yes, it can be done using pytorch tensor operations (and without loops),
so it should be efficient. (However, I don’t see any way to do this using
pytorch’s kl_div(), itself.)

The idea is that you need to get a batch-wise “outer product” and einsum()
seems to me to be the most straightforward way to achieve this.
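To see why an “outer product” is all you need: with reduction = 'sum',
kl_div (input, target) computes sum (target * (target.log() - input)), so each
pair-wise entry is sum_k B[j, k] * log B[j, k] - sum_k A[i, k] * B[j, k]. The
first term depends only on j, and the second term is exactly a batch of dot
products, which is what einsum ('ik, jk -> ij', A, B) produces. A minimal
single-pair sketch of that identity (self-contained, with its own small tensors):

import torch

a = torch.randn (5)   # one row of the "input" (treated as log-probabilities)
b = torch.rand (5)    # one row of the "target" (treated as probabilities)

# kl_div (input, target, reduction = 'sum') = sum (target * (target.log() - input))
lhs = torch.nn.functional.kl_div (a, b, reduction = 'sum')
# the same value split into a target-only term minus a dot product
rhs = (b * b.log()).sum() - a @ b
print (torch.allclose (lhs, rhs))   # prints True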

Here we check a loopless implementation against using kl_div()
with loops:

>>> import torch
>>> torch.__version__
'1.9.0'
>>>
>>> _ = torch.manual_seed (2021)
>>>
>>> A = torch.randn (3, 5)
>>> B = torch.rand (3, 5)
>>>
>>> A
tensor([[ 2.2871,  0.6413, -0.8615, -0.3649, -0.6931],
        [ 0.9023, -2.7183, -1.4478,  0.6238,  0.4822],
        [-2.3055,  0.9176,  1.5674, -0.1284, -1.0042]])
>>> B
tensor([[0.8303, 0.5216, 0.7438, 0.8290, 0.0219],
        [0.0813, 0.0172, 0.1464, 0.7492, 0.9450],
        [0.6737, 0.1135, 0.7421, 0.7810, 0.9446]])
>>>
>>> batch_klA = torch.zeros (3, 3)
>>>
>>> for  i in range (3):
...     for  j in range (3):
...         batch_klA[i, j] = torch.nn.functional.kl_div (A[i], B[j], reduction = 'sum')
...
>>> batch_klB = (B * B.log()).sum (dim = 1) - torch.einsum ('ik, jk -> ij', A, B)
>>>
>>> batch_klA
tensor([[-2.2283,  0.0326, -1.0158],
        [ 0.2646, -1.5627, -1.1489],
        [-0.5549,  0.1626,  0.3534]])
>>> batch_klB
tensor([[-2.2283,  0.0326, -1.0158],
        [ 0.2646, -1.5627, -1.1489],
        [-0.5549,  0.1626,  0.3534]])
>>>
>>> torch.allclose (batch_klA, batch_klB)
True
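If you want something with a cdist-like interface, you could wrap the
one-liner in a small helper. This is just an illustrative sketch (the name
pairwise_kl_div is my own); as above, it assumes A holds log-probabilities
and B holds probabilities:

import torch

def pairwise_kl_div (A, B):
    # returns a (A.shape[0], B.shape[0]) tensor whose (i, j) entry equals
    # torch.nn.functional.kl_div (A[i], B[j], reduction = 'sum')
    return (B * B.log()).sum (dim = 1) - torch.einsum ('ik, jk -> ij', A, B)

For A and B of shape (batch_size, dimension), this returns the
(batch_size x batch_size) tensor asked for in the original question.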

Best.

K. Frank

Hi @KFrank, when I tried the above implementation with A = B, it returned a tensor containing all nans, whereas we should get a tensor with a zero diagonal. I’m still trying to figure out the problem.

Hi Hoang!

Please note that both the implementation I posted above and, by default,
pytorch’s kl_div() itself expect the input (A in the example above) to be
log-probabilities and the target (B) to be probabilities.
If you pass in a probability that is zero or negative, you will get a nan.
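For example, if your A and B start out as raw, unnormalized scores, a sketch
along these lines (assuming softmax normalization along the feature dimension
is what you want) converts them first, and gives a zero diagonal when both
sets come from the same samples:

import torch

raw = torch.randn (3, 5)                 # raw, unnormalized samples
A = torch.log_softmax (raw, dim = 1)     # log-probabilities for the "input" role
B = torch.softmax (raw, dim = 1)         # probabilities for the "target" role

pairwise_kl = (B * B.log()).sum (dim = 1) - torch.einsum ('ik, jk -> ij', A, B)
print (pairwise_kl.diag())   # (numerically) zero on the diagonal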

If this doesn’t clear up your issue, please post a short, complete, runnable
example that reproduces the issue you see.

Best.

K. Frank