How to efficiently compute a pairwise KL divergence matrix for a batch of Gaussian distributions?

If you do the computation on mu_p = mu[:, None, :] and mu_q = mu[None, :, :] (and similarly for std), and replace dim=1 with dim=-1 (i.e. the last dimension), you should be good to go thanks to broadcasting: the intermediate tensors have shape [B, B, H], and reducing over the last dimension leaves the [B, B] matrix. Obviously, it would be even neater to avoid materializing the [B, B, H] tensors, but that is much harder (torch.compile or similar might do it automatically, now or in the future).
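To make the broadcasting idea concrete, here is a minimal sketch. It assumes diagonal Gaussians given as mu and std tensors of shape [B, H] and uses the closed-form KL between univariate normals, summed over the feature dimension. The function name pairwise_kl_diag_gaussians is made up for illustration, and your per-element formula may differ from this one.

```python
import torch


def pairwise_kl_diag_gaussians(mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Return a [B, B] matrix whose (i, j) entry is KL(N_i || N_j).

    Assumes mu and std have shape [B, H] and describe diagonal Gaussians.
    """
    # p varies along dim 0, q along dim 1; broadcasting pairs every p with every q.
    mu_p, std_p = mu[:, None, :], std[:, None, :]  # [B, 1, H]
    mu_q, std_q = mu[None, :, :], std[None, :, :]  # [1, B, H]

    # Closed-form KL per feature; this materializes a [B, B, H] tensor.
    kl = (
        torch.log(std_q / std_p)
        + (std_p.pow(2) + (mu_p - mu_q).pow(2)) / (2 * std_q.pow(2))
        - 0.5
    )

    # Reduce over the last (feature) dimension to get the [B, B] matrix.
    return kl.sum(dim=-1)


mu = torch.randn(4, 3)
std = torch.rand(4, 3) + 0.5  # keep the standard deviations strictly positive
kl_matrix = pairwise_kl_diag_gaussians(mu, std)  # shape [4, 4]
```

The diagonal of the result should be zero (up to rounding), since each distribution is compared with itself there. Memory use grows as B * B * H, which is the cost mentioned above.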

Best regards

Thomas
