# How to efficiently compute a pairwise KL divergence matrix of a batch of Gaussian distributions?

Hi! I have a `mu` of shape `[B, H]` and a `std` of shape `[B, H]`.
Each row of `mu` and `std` holds the parameters of a (diagonal) Gaussian distribution. Is there an efficient way to compute the pairwise KL divergence matrix of shape `[B, B]`?

Code to calculate the KL divergence of two Gaussian distributions, which can broadcast over batch dims:

``````
k = mu_q.size(1)  # dimensionality of the Gaussians
mu_diff = mu_p - mu_q
mu_diff_sq = torch.mul(mu_diff, mu_diff)
# log-determinants of the diagonal covariances (sum of log variances)
logdet_std_q = torch.sum(2 * torch.log(torch.clamp(std_q, min=1e-8)), dim=1)
logdet_std_p = torch.sum(2 * torch.log(torch.clamp(std_p, min=1e-8)), dim=1)
# trace term plus Mahalanobis term
fs = torch.sum(torch.div(std_q**2, std_p**2), dim=1) + torch.sum(
    torch.div(mu_diff_sq, std_p**2), dim=1
)
kl_divergence = (fs - k + logdet_std_p - logdet_std_q) * 0.5
``````
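
For reference, the snippet implements the closed-form KL divergence between two diagonal Gaussians $q = \mathcal{N}(\mu_q, \mathrm{diag}(\sigma_q^2))$ and $p = \mathcal{N}(\mu_p, \mathrm{diag}(\sigma_p^2))$ in $k$ dimensions:

$$
\mathrm{KL}(q \,\|\, p) = \frac{1}{2}\left[\sum_{h=1}^{k} \frac{\sigma_{q,h}^2}{\sigma_{p,h}^2} + \sum_{h=1}^{k} \frac{(\mu_{p,h}-\mu_{q,h})^2}{\sigma_{p,h}^2} - k + \sum_{h=1}^{k}\log\sigma_{p,h}^2 - \sum_{h=1}^{k}\log\sigma_{q,h}^2\right]
$$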

If you do the computation on `mu_p = mu[:, None, :]` and `mu_q = mu[None, :, :]`, similarly for `std` and replace the `dim=1` by `dim=-1` (i.e. the last dimension), you should be good to go thanks to broadcasting. Obviously, it would be even neater if we could avoid materializing [B, B, H] matrices, but that is much harder (and might be done by torch.compile or similar automatically now or in the future).
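
A minimal sketch of that broadcast call, assuming the snippet above is wrapped in a function `kl_div(mu_q, std_q, mu_p, std_p)` with `dim=-1` (the name is just for illustration):

``````
mu_q, std_q = mu[:, None, :], std[:, None, :]   # [B, 1, H]
mu_p, std_p = mu[None, :, :], std[None, :, :]   # [1, B, H]
kl_matrix = kl_div(mu_q, std_q, mu_p, std_p)    # [B, B]; kl_matrix[i, j] = KL(q_i || p_j)
``````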

Best regards

Thomas


It is surprisingly faster than `mu @ mu.T`.

Benchmark code:

``````
import torch
import time

def kl_div(mu_q, std_q, mu_p, std_p):
    """Computes the KL divergence between two diagonal Gaussian distributions.

    This computes KL(q||p), which is not symmetric. It quantifies how far
    the estimated distribution q is from the true distribution p."""
    k = mu_q.size(-1)
    mu_diff = mu_p - mu_q
    mu_diff_sq = torch.mul(mu_diff, mu_diff)
    logdet_std_q = torch.sum(2 * torch.log(torch.clamp(std_q, min=1e-8)), dim=-1)
    logdet_std_p = torch.sum(2 * torch.log(torch.clamp(std_p, min=1e-8)), dim=-1)
    fs = torch.sum(torch.div(std_q**2, std_p**2), dim=-1) + torch.sum(
        torch.div(mu_diff_sq, std_p**2), dim=-1
    )
    kl_divergence = (fs - k + logdet_std_p - logdet_std_q) * 0.5
    return kl_divergence

if __name__ == "__main__":
    B = 100
    D = 256
    times = 100
    mu = torch.randn(B, D).cuda()
    std = torch.randn(B, D).cuda()

    # pairwise KL via broadcasting
    start = time.time()
    for i in range(times):
        kl = kl_div(mu[:, None, :], std[:, None, :], mu[None, :, :], std[None, :, :])
    end = time.time()
    print(end - start, kl.shape)

    # plain matmul for comparison
    start = time.time()
    for i in range(times):
        kl = mu @ mu.T
    end = time.time()
    print(end - start, kl.shape)

    # broadcasted elementwise product + sum for comparison
    start = time.time()
    for i in range(times):
        kl = (mu[:, None, :] * mu[None, :, :]).sum(dim=-1)
    end = time.time()
    print(end - start, kl.shape)
``````

outputs:

``````0.008843183517456055 torch.Size([100, 100])
0.4114084243774414 torch.Size([100, 100])
0.0010275840759277344 torch.Size([100, 100])
``````

I think you want `time.perf_counter()` and `torch.cuda.synchronize()` before every timing measurement.
(I have this in my training materials somewhere, but here is an old blog post discussing that in "When timing CUDA models": Lernapparat - Machine Learning, if you ignore the other bits.)
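
A minimal sketch of that timing pattern, assuming a CUDA-capable machine and that `fn` is the operation under test (the helper name `time_cuda` is just for illustration):

``````
import time
import torch

def time_cuda(fn, iters=100):
    torch.cuda.synchronize()          # make sure pending GPU work is done before starting
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()          # wait for the queued kernels to finish
    return (time.perf_counter() - start) / iters
``````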

Best regards

Thomas

I have put `perf_counter` and `synchronize` into my code, but `kl_div` is still faster than `mu @ mu.T`.

``````
import torch
import time

def kl_div(mu_q, std_q, mu_p, std_p):
    """Computes the KL divergence between two diagonal Gaussian distributions.

    This computes KL(q||p), which is not symmetric. It quantifies how far
    the estimated distribution q is from the true distribution p."""
    k = mu_q.size(-1)
    mu_diff = mu_p - mu_q
    mu_diff_sq = torch.mul(mu_diff, mu_diff)
    logdet_std_q = torch.sum(2 * torch.log(torch.clamp(std_q, min=1e-8)), dim=-1)
    logdet_std_p = torch.sum(2 * torch.log(torch.clamp(std_p, min=1e-8)), dim=-1)
    fs = torch.sum(torch.div(std_q**2, std_p**2), dim=-1) + torch.sum(
        torch.div(mu_diff_sq, std_p**2), dim=-1
    )
    kl_divergence = (fs - k + logdet_std_p - logdet_std_q) * 0.5
    return kl_divergence

if __name__ == "__main__":
    B = 100
    D = 256
    times = 1
    mu = torch.randn(B, D).cuda()
    std = torch.randn(B, D).cuda()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for i in range(times):
        kl = kl_div(mu.unsqueeze(1), std.unsqueeze(1), mu.unsqueeze(0), std.unsqueeze(0))
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    start = time.perf_counter()
    for i in range(times):
        kl = kl_div(mu.view(B, 1, D), std.view(B, 1, D), mu.view(1, B, D), std.view(1, B, D))
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    start = time.perf_counter()
    for i in range(times):
        kl = kl_div(mu[:, None, :], std[:, None, :], mu[None, :, :], std[None, :, :])
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    start = time.perf_counter()
    for i in range(times):
        kl1 = []
        for m, s in zip(mu, std):
            kl1.append(kl_div(m.unsqueeze(0).expand_as(mu), s.unsqueeze(0).expand_as(std), mu, std))
        kl1 = torch.stack(kl1)
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl1.shape)

    print(kl)
    print(kl1)
    assert ((kl - kl1).abs() < 1e-10).all()

    start = time.perf_counter()
    for i in range(times):
        kl = mu @ mu.T
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    start = time.perf_counter()
    for i in range(times):
        kl1 = (mu[:, None, :] * mu[None, :, :]).sum(dim=-1)
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl1.shape)

    print(kl)
    print(kl1)
    assert ((kl - kl1).abs() < 1e-4).all()

    start = time.perf_counter()
    for i in range(times):
        kl = mu.T
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    start = time.perf_counter()
    muT = mu.T
    for i in range(times):
        kl = mu @ muT
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)
``````

outputs:

``````0.000576397986151278 torch.Size([100, 100])
0.00018305005505681038 torch.Size([100, 100])
0.00018039101269096136 torch.Size([100, 100])
0.0067256499314680696 torch.Size([100, 100])
tensor([[     0.0000,  57080.6406,  66168.9922,  ...,  53879.9414,
368704.2188,  79685.1484],
[ 39291.7500,      0.0000, 107659.6328,  ..., 142458.5469,
399184.1562,  47731.7344],
[ 56672.7148,  88284.0703,      0.0000,  ...,  73333.1875,
956130.2500,  51419.0156],
...,
[ 79648.5938,  60400.5156,  77861.4609,  ...,      0.0000,
243438.4062, 108050.6641],
[151945.2969,  58202.5508,  33863.8906,  ...,  31096.5391,
0.0000, 103794.0938],
[ 50227.0508,  78881.3125,  76599.9531,  ...,  58455.1133,
310935.7500,      0.0000]], device='cuda:0')
tensor([[     0.0000,  57080.6406,  66168.9922,  ...,  53879.9414,
368704.2188,  79685.1484],
[ 39291.7500,      0.0000, 107659.6328,  ..., 142458.5469,
399184.1562,  47731.7344],
[ 56672.7148,  88284.0703,      0.0000,  ...,  73333.1875,
956130.2500,  51419.0156],
...,
[ 79648.5938,  60400.5156,  77861.4609,  ...,      0.0000,
243438.4062, 108050.6641],
[151945.2969,  58202.5508,  33863.8906,  ...,  31096.5391,
0.0000, 103794.0938],
[ 50227.0508,  78881.3125,  76599.9531,  ...,  58455.1133,
310935.7500,      0.0000]], device='cuda:0')
0.39876252599060535 torch.Size([100, 100])
0.00013514305464923382 torch.Size([100, 100])
tensor([[286.4470,  12.7784, -24.8256,  ..., -16.4116,  -4.7319,  22.7014],
[ 12.7784, 247.9196,  -5.5962,  ..., -26.5139,  17.1346,   7.6810],
[-24.8256,  -5.5962, 294.9920,  ..., -31.1086,   8.4197,  -2.6804],
...,
[-16.4116, -26.5139, -31.1086,  ..., 252.4709,  39.1322, -24.4404],
[ -4.7319,  17.1346,   8.4197,  ...,  39.1322, 250.0481, -23.8120],
[ 22.7014,   7.6810,  -2.6804,  ..., -24.4404, -23.8120, 245.9375]],
device='cuda:0')
tensor([[286.4470,  12.7784, -24.8256,  ..., -16.4116,  -4.7319,  22.7014],
[ 12.7784, 247.9196,  -5.5962,  ..., -26.5139,  17.1346,   7.6810],
[-24.8256,  -5.5962, 294.9920,  ..., -31.1086,   8.4197,  -2.6804],
...,
[-16.4116, -26.5139, -31.1086,  ..., 252.4708,  39.1322, -24.4404],
[ -4.7319,  17.1346,   8.4197,  ...,  39.1322, 250.0481, -23.8120],
[ 22.7014,   7.6810,  -2.6804,  ..., -24.4404, -23.8120, 245.9375]],
device='cuda:0')
1.8235063180327415e-05 torch.Size([256, 100])
3.278802614659071e-05 torch.Size([100, 100])
``````

Time consumed (fastest to slowest):
`mu @ muT` < `(mu[:, None, :] * mu[None, :, :]).sum(dim=-1)` < `kl_div` < `mu @ mu.T`

The first `mu @ mu.T` is likely slow only because the very first matmul pays a one-time setup cost (e.g. cuBLAS initialization); once warmed up, as in the `mu @ muT` run, the matmul is the fastest of the bunch.
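
A small sketch of adding a warm-up step before timing, so that one-time cost does not count against the first measurement (assuming the same `mu` already on the GPU):

``````
# warm-up: run the op once before timing
_ = mu @ mu.T
torch.cuda.synchronize()

start = time.perf_counter()
for _ in range(100):
    kl = mu @ mu.T
torch.cuda.synchronize()
print((time.perf_counter() - start) / 100, kl.shape)
``````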

Yeah, but now you can trust your numbers more.
