KL divergence between two multivariate Gaussians without a for loop

Hi everyone,
I have 2 multivariate Gaussians and I want to compute KL-divergence between them. The shape of mu1, mu2, std1, std2 is (batch_size, 128). Currently I am computing this using for loop to cmpute this. Can this be done in a vectorized way?

import torch

def compute_kl_div(mu1, mu2, std1, std2):
    kl_sum = 0
    k = mu1.shape[1]  # dimensionality of the Gaussians (128 here)
    for i in range(mu1.shape[0]):
        # diagonal covariance matrices for sample i
        cov1 = torch.diag(std1[i])
        cov2 = torch.diag(std2[i])
        # log det(cov2) - log det(cov1)
        a = torch.logdet(cov2) - torch.logdet(cov1)
        cov2_inv = torch.inverse(cov2)
        # trace(cov2^{-1} cov1)
        b = torch.trace(torch.mm(cov2_inv, cov1))
        # (mu2 - mu1)^T cov2^{-1} (mu2 - mu1)
        c = torch.mm(torch.mm((mu2[i] - mu1[i]).unsqueeze(0), cov2_inv), (mu2[i] - mu1[i]).unsqueeze(1))
        kl_sum += (a - k + b + c)
    return 0.5 * kl_sum

Each of these functions is either batchable or has a batchable equivalent, so you could just convert the code line by line.
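For example, a rough sketch of that line-by-line conversion (untested, the name compute_kl_div_batched is just for illustration, and it assumes, as in your code, that std1/std2 hold the diagonal entries of the covariance matrices) might look like this:

import torch

def compute_kl_div_batched(mu1, mu2, std1, std2):
    # (batch_size, 128, 128) diagonal covariance matrices, built in one go
    cov1 = torch.diag_embed(std1)
    cov2 = torch.diag_embed(std2)
    # logdet, inverse and matmul all broadcast over the leading batch dimension
    a = torch.logdet(cov2) - torch.logdet(cov1)                     # (batch_size,)
    cov2_inv = torch.linalg.inv(cov2)                               # (batch_size, 128, 128)
    b = torch.diagonal(cov2_inv @ cov1, dim1=-2, dim2=-1).sum(-1)   # batched trace
    diff = (mu2 - mu1).unsqueeze(-1)                                # (batch_size, 128, 1)
    c = (diff.transpose(-2, -1) @ cov2_inv @ diff).squeeze(-1).squeeze(-1)
    k = mu1.shape[-1]
    return 0.5 * (a - k + b + c).sum()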
But then, explicitly constructing the diagonal matrices seems excessively inefficient: logdet and inverse are trivial for diagonal matrices and expensive for general ones, so I would recommend revisiting the maths here. (The GP people do a lot of these calculations, so GPyTorch or my ancient CandleGP will have plenty of examples, even if the latter doesn't quite have what you want.)
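Concretely (again just a sketch, assuming std1/std2 are the variances on the diagonal, exactly as your torch.diag call treats them; if they are standard deviations, square them first), the whole KL reduces to elementwise operations and sums:

import torch

def compute_kl_div_diag(mu1, mu2, var1, var2):
    # KL(N(mu1, diag(var1)) || N(mu2, diag(var2))) summed over the batch,
    # without ever materialising a 128x128 matrix
    a = torch.log(var2).sum(-1) - torch.log(var1).sum(-1)   # log-det ratio
    b = (var1 / var2).sum(-1)                               # trace term
    c = ((mu2 - mu1) ** 2 / var2).sum(-1)                   # quadratic (Mahalanobis) term
    k = mu1.shape[-1]
    return 0.5 * (a - k + b + c).sum()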

It may be convenient to use einsum for the summations.
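For instance, the batched trace and the quadratic form from the first sketch above could each be written as a single einsum call (same hypothetical cov1, cov2_inv, mu1, mu2 as before):

b = torch.einsum('bij,bji->b', cov2_inv, cov1)          # trace(cov2_inv @ cov1) per batch element
diff = mu2 - mu1
c = torch.einsum('bi,bij,bj->b', diff, cov2_inv, diff)  # (mu2 - mu1)^T cov2_inv (mu2 - mu1)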

Best regards

Thomas