Hi! I have a mu of shape [B, H] and a std of shape [B, H].
Each row of mu and std gives the parameters of a (diagonal) Gaussian distribution. Is there an efficient way to compute the pairwise KL-divergence matrix of shape [B, B]?
Here is code to calculate the KL divergence of two Gaussian distributions that can broadcast over batch dims:
If you do the computation on mu_p = mu[:, None, :] and mu_q = mu[None, :, :], similarly for std, and replace dim=1 with dim=-1 (i.e. the last dimension), you should be good to go thanks to broadcasting. Obviously, it would be even neater if we could avoid materializing [B, B, H] tensors, but that is much harder (and might be done automatically by torch.compile or similar, now or in the future).
import torch
import time


def kl_div(mu_q, std_q, mu_p, std_p):
    """Computes the KL divergence between the two given variational distributions.

    This computes KL(q||p), which is not symmetric: it quantifies how far
    the estimated distribution q is from the true distribution p."""
    k = mu_q.size(-1)
    mu_diff = mu_p - mu_q
    mu_diff_sq = torch.mul(mu_diff, mu_diff)
    logdet_std_q = torch.sum(2 * torch.log(torch.clamp(std_q, min=1e-8)), dim=-1)
    logdet_std_p = torch.sum(2 * torch.log(torch.clamp(std_p, min=1e-8)), dim=-1)
    fs = torch.sum(torch.div(std_q**2, std_p**2), dim=-1) + torch.sum(
        torch.div(mu_diff_sq, std_p**2), dim=-1
    )
    kl_divergence = (fs - k + logdet_std_p - logdet_std_q) * 0.5
    return kl_divergence
if __name__ == "__main__":
    B = 100
    D = 256
    times = 100
    mu = torch.randn(B, D).cuda()
    std = torch.randn(B, D).cuda()

    start = time.time()
    for i in range(times):
        kl = kl_div(mu[:, None, :], std[:, None, :], mu[None, :, :], std[None, :, :])
    end = time.time()
    print(end - start, kl.shape)

    start = time.time()
    for i in range(times):
        kl = mu @ mu.T
    end = time.time()
    print(end - start, kl.shape)

    start = time.time()
    for i in range(times):
        kl = (mu[:, None, :] * mu[None, :, :]).sum(dim=-1)
    end = time.time()
    print(end - start, kl.shape)
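As an aside on avoiding the [B, B, H] intermediate mentioned above: for diagonal Gaussians the pairwise KL does decompose into matrix products, so it can be done with [B, B] intermediates only. Here is a sketch (kl_div_pairwise is a name I'm introducing, and it assumes std > 0, matching the clamping in kl_div):

```python
import torch


def kl_div_pairwise(mu, std, eps=1e-8):
    """Pairwise KL(q_i || p_j) for rows of (mu, std), returning [B, B],
    without materializing a [B, B, H] intermediate.  Assumes std > 0."""
    k = mu.size(-1)
    var = std**2                  # [B, H]
    inv_var = 1.0 / var           # [B, H]
    # Trace term: sum_h var_q[i, h] / var_p[j, h], as a single matmul.
    trace = var @ inv_var.T       # [B, B]
    # Mahalanobis term sum_h (mu_p - mu_q)^2 / var_p, expanded into
    # three pieces: mu_q^2/var_p, -2*mu_q*mu_p/var_p, mu_p^2/var_p.
    maha = (
        (mu**2) @ inv_var.T
        - 2.0 * mu @ (mu * inv_var).T
        + (mu**2 * inv_var).sum(-1)[None, :]
    )
    # Log-determinant terms only depend on one row each, so broadcast.
    logdet = (2.0 * torch.log(torch.clamp(std, min=eps))).sum(-1)  # [B]
    return 0.5 * (trace + maha - k + logdet[None, :] - logdet[:, None])
```

This should agree with kl_div applied to the broadcast [:, None, :] / [None, :, :] views, up to floating-point error.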
I think you want time.perf_counter() and a torch.cuda.synchronize() before every timestamp.
(I have this in my training materials somewhere, but there is an old blog post discussing it under "When timing CUDA models": Lernapparat - Machine Learning, if you ignore the other bits.)
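A minimal sketch of that timing pattern (the helper name timed is mine, not from the post):

```python
import time

import torch


def timed(fn, iters=100):
    """Run fn `iters` times and return elapsed seconds, using
    time.perf_counter() and synchronizing the GPU before each
    timestamp so kernels still queued on the device are counted."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end = time.perf_counter()
    return end - start
```

Without the synchronize, CUDA kernel launches are asynchronous, so time.time() may return before the work has actually finished; this pattern would give the broadcast kl_div benchmark above a fair measurement.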