How to efficiently compute a pairwise KL divergence matrix for a batch of Gaussian distributions?

Hi! I have a mu of shape [B, H] and a std of shape [B, H].
Each row of mu and std holds the parameters of one (diagonal) Gaussian distribution. Is there an efficient way to compute the pairwise KL divergence matrix of shape [B, B]?

Code that calculates the KL divergence between two Gaussian distributions and broadcasts over the batch dims:

        # KL(q || p) between Gaussians parameterised by per-element mean/std,
        # reduced over dim 1. mu_q, std_q, mu_p, std_p are defined outside
        # this snippet -- assumed to be [B, H] tensors as described in the
        # question (TODO confirm).
        # Number of elements along dim 1 (the "- k" term of the formula).
        k = mu_q.size(1)
        # Element-wise squared difference of the means.
        mu_diff = mu_p - mu_q
        mu_diff_sq = torch.mul(mu_diff, mu_diff)
        # Sum of log-variances (2 * log(std)); std is clamped to >= 1e-8 here
        # so that the log stays finite.
        logdet_std_q = torch.sum(2 * torch.log(torch.clamp(std_q, min=1e-8)), dim=1)
        logdet_std_p = torch.sum(2 * torch.log(torch.clamp(std_p, min=1e-8)), dim=1)
        # Variance-ratio term plus the squared mean difference scaled by p's
        # variance. NOTE: the divisions use the raw (unclamped) std_p.
        fs = torch.sum(torch.div(std_q**2, std_p**2), dim=1) + torch.sum(
            torch.div(mu_diff_sq, std_p**2), dim=1
        )
        # Combine the terms and halve.
        kl_divergence = (fs - k + logdet_std_p - logdet_std_q) * 0.5

If you do the computation on mu_p = mu[:, None, :] and mu_q = mu[None, :, :], similarly for std and replace the dim=1 by dim=-1 (i.e. the last dimension), you should be good to go thanks to broadcasting. Obviously, it would be even neater if we could avoid materializing [B, B, H] matrices, but that is much harder (and might be done by torch.compile or similar automatically now or in the future).

Best regards

Thomas

1 Like

It is surprisingly faster than mu @ mu.T.

Benchmark code:

import torch
import time

def kl_div(mu_q, std_q, mu_p, std_p, *, eps=1e-8):
    """Return KL(q || p) between diagonal Gaussians, summed over the last dim.

    KL is not symmetric: this measures how far the distribution q is from
    the distribution p. All four inputs are broadcast against each other, so
    passing ``mu[:, None, :]`` / ``std[:, None, :]`` for q and
    ``mu[None, :, :]`` / ``std[None, :, :]`` for p yields the full pairwise
    matrix with entry ``[i, j] = KL(q_i || p_j)``.

    Args:
        mu_q: Means of q; the last dimension indexes the Gaussian's components.
        std_q: Standard deviations of q, broadcastable with ``mu_q``.
        mu_p: Means of p, broadcastable with ``mu_q``.
        std_p: Standard deviations of p, broadcastable with ``mu_q``.
        eps: Lower bound applied to both standard deviations before they are
            used anywhere in the formula.

    Returns:
        Tensor with the broadcast shape of the inputs minus the last dim.
    """
    # Clamp once and reuse the result everywhere: the log-determinant and the
    # variance ratio then see the same std, and a zero std can no longer
    # cause a division by zero (inf/nan) in the ratio terms.
    std_q = torch.clamp(std_q, min=eps)
    std_p = torch.clamp(std_p, min=eps)
    var_q = std_q**2
    var_p = std_p**2

    k = mu_q.size(-1)
    mu_diff = mu_p - mu_q

    # log|Sigma| of a diagonal covariance is the sum of the log-variances.
    logdet_q = torch.sum(2 * torch.log(std_q), dim=-1)
    logdet_p = torch.sum(2 * torch.log(std_p), dim=-1)

    # tr(Sigma_p^-1 Sigma_q) + (mu_p - mu_q)^T Sigma_p^-1 (mu_p - mu_q)
    fs = torch.sum(var_q / var_p, dim=-1) + torch.sum(
        mu_diff * mu_diff / var_p, dim=-1
    )
    return (fs - k + logdet_p - logdet_q) * 0.5

# Benchmark: pairwise KL via broadcasting vs. two ways of computing mu @ mu.T.
# NOTE(review): the timers below are not bracketed by torch.cuda.synchronize(),
# and CUDA work is presumably queued asynchronously, so these numbers may
# reflect launch overhead rather than kernel run time -- confirm before
# drawing conclusions.
if __name__ == "__main__":
    B = 100
    D = 256
    times = 100
    mu = torch.randn(B, D).cuda()
    # NOTE(review): randn yields negative values too, so this "std" is not a
    # valid standard deviation; kl_div only clamps it inside the log.
    std = torch.randn(B, D).cuda()
    # 1) Pairwise KL: q is broadcast as [B, 1, D], p as [1, B, D] -> [B, B].
    start = time.time()
    for i in range(times):
        kl = kl_div(mu[:, None, :], std[:, None, :], mu[None, :, :], std[None, :, :])
    end = time.time()
    print(end - start, kl.shape)

    # 2) Gram matrix via matmul. This is the first matmul in the script, so
    # the measurement may include one-time library setup -- TODO confirm.
    start = time.time()
    for i in range(times):
        kl = mu @ mu.T
    end = time.time()
    print(end - start, kl.shape)

    # 3) Same Gram matrix via a broadcast product summed over the last dim.
    start = time.time()
    for i in range(times):
        kl = (mu[:, None, :] * mu[None, :, :]).sum(dim=-1)
    end = time.time()
    print(end - start, kl.shape)

outputs:

0.008843183517456055 torch.Size([100, 100])
0.4114084243774414 torch.Size([100, 100])
0.0010275840759277344 torch.Size([100, 100])

I think you want time.perf_counter(), and a torch.cuda.synchronize() before every time measurement.
(I have this in my training materials somewhere, but here is an old blog post discussing that in “When timing CUDA models”: Lernapparat - Machine Learning if you ignore the other bits.)

Best regards

Thomas

I have added perf_counter and synchronize to my code, but kl_div is still faster than mu @ mu.T.

import torch
import time

def kl_div(mu_q, std_q, mu_p, std_p, *, eps=1e-8):
    """Return KL(q || p) between diagonal Gaussians, summed over the last dim.

    KL is not symmetric: this measures how far the distribution q is from
    the distribution p. All four inputs are broadcast against each other, so
    passing ``mu[:, None, :]`` / ``std[:, None, :]`` for q and
    ``mu[None, :, :]`` / ``std[None, :, :]`` for p yields the full pairwise
    matrix with entry ``[i, j] = KL(q_i || p_j)``.

    Args:
        mu_q: Means of q; the last dimension indexes the Gaussian's components.
        std_q: Standard deviations of q, broadcastable with ``mu_q``.
        mu_p: Means of p, broadcastable with ``mu_q``.
        std_p: Standard deviations of p, broadcastable with ``mu_q``.
        eps: Lower bound applied to both standard deviations before they are
            used anywhere in the formula.

    Returns:
        Tensor with the broadcast shape of the inputs minus the last dim.
    """
    # Clamp once and reuse the result everywhere: the log-determinant and the
    # variance ratio then see the same std, and a zero std can no longer
    # cause a division by zero (inf/nan) in the ratio terms.
    std_q = torch.clamp(std_q, min=eps)
    std_p = torch.clamp(std_p, min=eps)
    var_q = std_q**2
    var_p = std_p**2

    k = mu_q.size(-1)
    mu_diff = mu_p - mu_q

    # log|Sigma| of a diagonal covariance is the sum of the log-variances.
    logdet_q = torch.sum(2 * torch.log(std_q), dim=-1)
    logdet_p = torch.sum(2 * torch.log(std_p), dim=-1)

    # tr(Sigma_p^-1 Sigma_q) + (mu_p - mu_q)^T Sigma_p^-1 (mu_p - mu_q)
    fs = torch.sum(var_q / var_p, dim=-1) + torch.sum(
        mu_diff * mu_diff / var_p, dim=-1
    )
    return (fs - k + logdet_p - logdet_q) * 0.5

# Benchmark with perf_counter and torch.cuda.synchronize() around each timed
# section: three equivalent broadcast spellings of the pairwise KL, a
# row-by-row reference, and several ways of computing mu @ mu.T.
# NOTE(review): times = 1 and there is no warm-up run, so each number is a
# single measurement that may include one-time setup costs -- TODO confirm
# with a warm-up iteration before comparing the variants.
if __name__ == "__main__":
    B = 100
    D = 256
    times = 1
    mu = torch.randn(B, D).cuda()
    # NOTE(review): randn yields negative values too, so this "std" is not a
    # valid standard deviation; kl_div only clamps it inside the log.
    std = torch.randn(B, D).cuda()
    torch.cuda.synchronize()

    # 1) Pairwise KL, q/p expanded with unsqueeze ([B, 1, D] vs. [1, B, D]).
    start = time.perf_counter()
    for i in range(times):
        kl = kl_div(mu.unsqueeze(1), std.unsqueeze(1), mu.unsqueeze(0), std.unsqueeze(0))
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    # 2) Same computation, shapes produced with view.
    start = time.perf_counter()
    for i in range(times):
        kl = kl_div(mu.view(B, 1, D), std.view(B, 1, D), mu.view(1, B, D), std.view(1, B, D))
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    # 3) Same computation, shapes produced with None-indexing.
    start = time.perf_counter()
    for i in range(times):
        kl = kl_div(mu[:, None, :], std[:, None, :], mu[None, :, :], std[None, :, :])
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    # 4) Reference: one kl_div call per row (that row as q, all rows as p),
    # stacked into a [B, B] matrix.
    start = time.perf_counter()
    for i in range(times):
        kl1 = []
        for m, s in zip(mu, std):
            kl1.append(kl_div(m.unsqueeze(0).expand_as(mu), s.unsqueeze(0).expand_as(std), mu, std))
        kl1 = torch.stack(kl1)
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl1.shape)

    # The broadcast result (variant 3) must match the row-by-row reference.
    print(kl)
    print(kl1)
    assert ((kl - kl1).abs() < 1e-10).all()

    # 5) Gram matrix via matmul. This is the first matmul in the script, so
    # the measurement may include one-time library setup -- TODO confirm.
    start = time.perf_counter()
    for i in range(times):
        kl = mu @ mu.T
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    # 6) Same Gram matrix via a broadcast product summed over the last dim.
    start = time.perf_counter()
    for i in range(times):
        kl1 = (mu[:, None, :] * mu[None, :, :]).sum(dim=-1)
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl1.shape)

    # Both Gram-matrix variants must agree up to a 1e-4 absolute tolerance.
    print(kl)
    print(kl1)
    assert ((kl - kl1).abs() < 1e-4).all()

    # 7) Cost of taking the transpose on its own.
    start = time.perf_counter()
    for i in range(times):
        kl = mu.T
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

    # 8) Matmul against a transpose taken once outside the loop. Runs after
    # the matmul in 5), so it is presumably not a like-for-like comparison
    # with that first call -- TODO confirm.
    start = time.perf_counter()
    muT = mu.T
    for i in range(times):
        kl = mu @ muT
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(end - start, kl.shape)

outputs:

0.000576397986151278 torch.Size([100, 100])
0.00018305005505681038 torch.Size([100, 100])
0.00018039101269096136 torch.Size([100, 100])
0.0067256499314680696 torch.Size([100, 100])
tensor([[     0.0000,  57080.6406,  66168.9922,  ...,  53879.9414,
         368704.2188,  79685.1484],
        [ 39291.7500,      0.0000, 107659.6328,  ..., 142458.5469,
         399184.1562,  47731.7344],
        [ 56672.7148,  88284.0703,      0.0000,  ...,  73333.1875,
         956130.2500,  51419.0156],
        ...,
        [ 79648.5938,  60400.5156,  77861.4609,  ...,      0.0000,
         243438.4062, 108050.6641],
        [151945.2969,  58202.5508,  33863.8906,  ...,  31096.5391,
              0.0000, 103794.0938],
        [ 50227.0508,  78881.3125,  76599.9531,  ...,  58455.1133,
         310935.7500,      0.0000]], device='cuda:0')
tensor([[     0.0000,  57080.6406,  66168.9922,  ...,  53879.9414,
         368704.2188,  79685.1484],
        [ 39291.7500,      0.0000, 107659.6328,  ..., 142458.5469,
         399184.1562,  47731.7344],
        [ 56672.7148,  88284.0703,      0.0000,  ...,  73333.1875,
         956130.2500,  51419.0156],
        ...,
        [ 79648.5938,  60400.5156,  77861.4609,  ...,      0.0000,
         243438.4062, 108050.6641],
        [151945.2969,  58202.5508,  33863.8906,  ...,  31096.5391,
              0.0000, 103794.0938],
        [ 50227.0508,  78881.3125,  76599.9531,  ...,  58455.1133,
         310935.7500,      0.0000]], device='cuda:0')
0.39876252599060535 torch.Size([100, 100])
0.00013514305464923382 torch.Size([100, 100])
tensor([[286.4470,  12.7784, -24.8256,  ..., -16.4116,  -4.7319,  22.7014],
        [ 12.7784, 247.9196,  -5.5962,  ..., -26.5139,  17.1346,   7.6810],
        [-24.8256,  -5.5962, 294.9920,  ..., -31.1086,   8.4197,  -2.6804],
        ...,
        [-16.4116, -26.5139, -31.1086,  ..., 252.4709,  39.1322, -24.4404],
        [ -4.7319,  17.1346,   8.4197,  ...,  39.1322, 250.0481, -23.8120],
        [ 22.7014,   7.6810,  -2.6804,  ..., -24.4404, -23.8120, 245.9375]],
       device='cuda:0')
tensor([[286.4470,  12.7784, -24.8256,  ..., -16.4116,  -4.7319,  22.7014],
        [ 12.7784, 247.9196,  -5.5962,  ..., -26.5139,  17.1346,   7.6810],
        [-24.8256,  -5.5962, 294.9920,  ..., -31.1086,   8.4197,  -2.6804],
        ...,
        [-16.4116, -26.5139, -31.1086,  ..., 252.4708,  39.1322, -24.4404],
        [ -4.7319,  17.1346,   8.4197,  ...,  39.1322, 250.0481, -23.8120],
        [ 22.7014,   7.6810,  -2.6804,  ..., -24.4404, -23.8120, 245.9375]],
       device='cuda:0')
1.8235063180327415e-05 torch.Size([256, 100])
3.278802614659071e-05 torch.Size([100, 100])

time consumed:
mu @ muT < (mu[:, None, :] * mu[None, :, :]).sum(dim=-1) < kl_div < mu @ mu.T

Yeah, but now you can trust your numbers more. :slight_smile:

1 Like