Hi everyone,
I have two batches of multivariate Gaussians with diagonal covariance, and I want to compute the KL divergence between them. mu1, mu2, std1 and std2 all have shape (batch_size, 128). Currently I compute this with a for loop over the batch. Can this be done in a vectorized way?
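For reference, the per-sample quantity is the standard closed form for the KL divergence between two k-dimensional Gaussians (k = 128 here). Assuming std1 and std2 hold standard deviations, so that each covariance is Σ_i = diag(σ_i²), it reduces to a sum of independent per-dimension terms:

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1,\Sigma_1)\,\|\,\mathcal{N}(\mu_2,\Sigma_2)\big)
  = \tfrac{1}{2}\Big[\log\frac{|\Sigma_2|}{|\Sigma_1|} - k
  + \operatorname{tr}(\Sigma_2^{-1}\Sigma_1)
  + (\mu_2-\mu_1)^\top\Sigma_2^{-1}(\mu_2-\mu_1)\Big]

% with \Sigma_i = \operatorname{diag}(\sigma_i^2), every term becomes elementwise:
  = \sum_{j=1}^{k}\Big[\log\frac{\sigma_{2,j}}{\sigma_{1,j}}
  + \frac{\sigma_{1,j}^2 + (\mu_{1,j}-\mu_{2,j})^2}{2\,\sigma_{2,j}^2} - \tfrac{1}{2}\Big]
```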
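import torch

def compute_kl_div(mu1, mu2, std1, std2):
    # Sum over the batch of KL(N(mu1, diag(std1**2)) || N(mu2, diag(std2**2)))
    batch_size, k = mu1.shape  # k = 128 here
    kl_sum = 0
    for i in range(batch_size):
        # Assuming std holds standard deviations, the covariance diagonal is std**2
        cov1 = torch.diag(std1[i] ** 2)
        cov2 = torch.diag(std2[i] ** 2)
        a = torch.logdet(cov2) - torch.logdet(cov1)  # log(|cov2| / |cov1|)
        cov2_inv = torch.linalg.inv(cov2)
        b = torch.trace(cov2_inv @ cov1)
        diff = (mu2[i] - mu1[i]).unsqueeze(1)  # column vector, shape (k, 1)
        c = (diff.T @ cov2_inv @ diff).squeeze()  # quadratic term as a scalar
        kl_sum += a - k + b + c  # subtract the dimension k, not a hard-coded 64
    return 0.5 * kl_sum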
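A minimal sketch of a vectorized version, assuming std1/std2 are standard deviations and that, like the loop, the result should be summed over the batch. Because the covariances are diagonal, the log-determinant, trace and quadratic terms all collapse to elementwise operations on (batch_size, 128) tensors, so no 128x128 matrices are built. The second function computes the same thing through torch.distributions, wrapping Normal in Independent so the last dimension is treated as one event. Both function names are my own, hypothetical choices.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence


def compute_kl_div_vectorized(mu1, mu2, std1, std2):
    """Summed KL(N(mu1, diag(std1**2)) || N(mu2, diag(std2**2))) over the batch.

    All inputs have shape (batch_size, k); std1/std2 are assumed to be
    standard deviations (not variances).
    """
    var1 = std1.pow(2)
    var2 = std2.pow(2)
    # Per-dimension closed form for diagonal Gaussians, shape (batch_size, k)
    kl = torch.log(std2 / std1) + (var1 + (mu1 - mu2).pow(2)) / (2 * var2) - 0.5
    # Summing over the last dim gives per-sample KL; .sum() also sums over the batch
    return kl.sum()


def compute_kl_div_distributions(mu1, mu2, std1, std2):
    """Same quantity via torch.distributions."""
    # Independent(..., 1) reinterprets the last dim as the event dimension,
    # so kl_divergence returns one value per sample, shape (batch_size,)
    p = Independent(Normal(mu1, std1), 1)
    q = Independent(Normal(mu2, std2), 1)
    return kl_divergence(p, q).sum()


# Usage with random inputs of the shape from the question
torch.manual_seed(0)
mu1, mu2 = torch.randn(4, 128), torch.randn(4, 128)
std1, std2 = torch.rand(4, 128) + 0.5, torch.rand(4, 128) + 0.5
print(compute_kl_div_vectorized(mu1, mu2, std1, std2))
print(compute_kl_div_distributions(mu1, mu2, std1, std2))
```

If a per-sample KL is needed instead of the batch total, dropping the final `.sum()` in the distributions version (or using `kl.sum(dim=-1)` in the manual one) returns a tensor of shape (batch_size,).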