Hi everyone, is there any way to efficiently calculate pair-wise KL divergence between 2 sets of samples A (batch_size x dimension) and B (batch_size x dimension), which returns a tensor of (batch_size x batch_size) in Pytorch. It’s similar to torch.cdist but with KL divergence rather than p-norm distance? Thank you so much.
Yes, it can be done using pytorch tensor operations (and without loops), so it should be efficient. (However, I don't see any way to do this using pytorch's built-in kl_div() directly.)

The idea is that you need to compute a batch-wise "outer product," and einsum() seems to me to be the most straightforward way to achieve this.

Here we check a loopless einsum() implementation against a version that loops over kl_div():
>>> import torch
>>> torch.__version__
'1.9.0'
>>>
>>> _ = torch.manual_seed (2021)
>>>
>>> A = torch.randn (3, 5)
>>> B = torch.rand (3, 5)
>>>
>>> A
tensor([[ 2.2871,  0.6413, -0.8615, -0.3649, -0.6931],
        [ 0.9023, -2.7183, -1.4478,  0.6238,  0.4822],
        [-2.3055,  0.9176,  1.5674, -0.1284, -1.0042]])
>>> B
tensor([[0.8303, 0.5216, 0.7438, 0.8290, 0.0219],
        [0.0813, 0.0172, 0.1464, 0.7492, 0.9450],
        [0.6737, 0.1135, 0.7421, 0.7810, 0.9446]])
>>>
>>> batch_klA = torch.zeros (3, 3)
>>>
>>> for i in range (3):
...     for j in range (3):
...         batch_klA[i, j] = torch.nn.functional.kl_div (A[i], B[j], reduction = 'sum')
...
>>> batch_klB = (B * B.log()).sum (dim = 1) - torch.einsum ('ik, jk -> ij', A, B)
>>>
>>> batch_klA
tensor([[-2.2283,  0.0326, -1.0158],
        [ 0.2646, -1.5627, -1.1489],
        [-0.5549,  0.1626,  0.3534]])
>>> batch_klB
tensor([[-2.2283,  0.0326, -1.0158],
        [ 0.2646, -1.5627, -1.1489],
        [-0.5549,  0.1626,  0.3534]])
>>>
>>> torch.allclose (batch_klA, batch_klB)
True
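If it helps, the loopless version can be wrapped up as a cdist-style helper. (This is just a sketch; pairwise_kl is a name I made up, not an existing pytorch function.)

import torch

def pairwise_kl (A, B):
    # A: (n, d) log-probabilities; B: (m, d) probabilities
    # returns an (n, m) tensor whose [i, j] entry equals
    # torch.nn.functional.kl_div (A[i], B[j], reduction = 'sum')
    return (B * B.log()).sum (dim = 1) - torch.einsum ('ik, jk -> ij', A, B)

With the A and B from the session above, pairwise_kl (A, B) reproduces batch_klB.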
Hi @KFrank, when I tried the above implementation with A = B, it returned a tensor containing all nans, whereas we would expect a tensor with zeros on the diagonal. I'm still trying to figure out the problem.
Please note that both the implementation I posted above and, by default, pytorch's kl_div() expect the input (A in the example above) to be log-probabilities and the target (B) to be probabilities. If you pass in a probability that is zero or negative, you will get a nan (or inf) from taking its log.
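As an illustration (my own minimal sketch, not from the original post): if you first convert your samples into proper probabilities with softmax() and pass their logs as the first argument, the diagonal of the pairwise-KL matrix comes out zero, as you expect:

import torch

# minimal sketch: build proper probabilities with softmax(), then
# apply the loopless formula with A = logP (log-probabilities), B = P
P = torch.softmax (torch.randn (3, 5), dim = 1)   # rows sum to one
logP = P.log()

pairKL = (P * P.log()).sum (dim = 1) - torch.einsum ('ik, jk -> ij', logP, P)

# the [i, i] entries are KL (P[i] || P[i]) = 0 (up to round-off)
print (torch.allclose (pairKL.diag(), torch.zeros (3), atol = 1e-5))   # True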
If this doesn’t clear up your issue, please post a short, complete, runnable
example that reproduces the issue you see.