Compute JS loss between Gaussian distributions parameterized by mu and log_var

My network outputs four tensors of shape `B x N`, where `B` is the batch size and `N` is the latent dimension.

1. `source_mu` and `source_log_var` are the `mu` and `log_var` of `N` Gaussian distributions.
2. `target_mu` and `target_log_var` are likewise the `mu` and `log_var` of `N` Gaussian distributions.

I want to measure the Jensen-Shannon (JS) divergence between the source and target distributions.
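
For reference, this is the standard definition I have in mind, writing $P$ for the source distribution, $Q$ for the target distribution, and $M$ for their equal-weight mixture (these symbols are my own notation, not names from the code):

$$
\mathrm{JS}(P \,\|\, Q) = \tfrac{1}{2}\,\mathrm{KL}(P \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(Q \,\|\, M),
\qquad M = \tfrac{1}{2}(P + Q),
\qquad \mathrm{KL}(P \,\|\, M) = \mathbb{E}_{x \sim P}\!\left[\log \frac{p(x)}{m(x)}\right]
$$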

Here is my code snippet. Can someone confirm whether it is correct?

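``````
import torch
import torch.nn.functional as F
from torch.distributions import Normal

# Method of my model class (rest of the class omitted).
def compute_js_loss(self, source_mu, source_log_var, target_mu, target_log_var):
    def get_prob(mu, log_var):
        # Convert log-variance to standard deviation.
        dist = Normal(mu, torch.exp(0.5 * log_var))
        # Draw one sample and evaluate the density at that sample.
        val = dist.sample()
        return dist.log_prob(val).exp()

    def kl_loss(p, q):
        # F.kl_div expects the first argument in log space.
        return F.kl_div(p, q, reduction="batchmean", log_target=False)

    source_prob = get_prob(source_mu, source_log_var)
    target_prob = get_prob(target_mu, target_log_var)

    log_mean_prob = (0.5 * (source_prob + target_prob)).log()
    js_loss = 0.5 * (kl_loss(log_mean_prob, source_prob) + kl_loss(log_mean_prob, target_prob))
    return js_loss
``````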
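
For comparison, here is a minimal sketch of one alternative: a Monte Carlo estimate of the JS divergence. The JS divergence between two Gaussians has no closed form, because the mixture of the two is not Gaussian, so sampling is a common workaround. This is an illustration under assumptions, not a confirmed fix: the function name `mc_js_divergence` and the parameter `n_samples` are hypothetical, the `N` latent dimensions are treated as one diagonal Gaussian per batch row, and the result is averaged over the batch. Unlike the snippet above, both densities are evaluated at the same sample points, and `rsample` is used so gradients can flow back to `mu` and `log_var`.

```python
import math

import torch
from torch.distributions import Independent, Normal


def mc_js_divergence(source_mu, source_log_var, target_mu, target_log_var, n_samples=64):
    """Monte Carlo estimate of JS(P || Q) for diagonal Gaussians (hypothetical helper).

    All inputs have shape (B, N). Returns a scalar: the estimate averaged over the batch.
    """
    # One N-dimensional diagonal Gaussian per batch row (assumption, see lead-in).
    p = Independent(Normal(source_mu, torch.exp(0.5 * source_log_var)), 1)
    q = Independent(Normal(target_mu, torch.exp(0.5 * target_log_var)), 1)

    def kl_to_mixture(dist):
        # Reparameterized samples keep the estimate differentiable.
        x = dist.rsample((n_samples,))  # shape (S, B, N)
        log_p = p.log_prob(x)           # shape (S, B)
        log_q = q.log_prob(x)           # shape (S, B)
        # log m(x) = log(0.5 * p(x) + 0.5 * q(x)), computed in log space for stability.
        log_m = torch.logsumexp(torch.stack([log_p, log_q]), dim=0) - math.log(2.0)
        # Average of log(dist(x) / m(x)) over the samples estimates KL(dist || M).
        return (dist.log_prob(x) - log_m).mean(dim=0)  # shape (B,)

    js = 0.5 * (kl_to_mixture(p) + kl_to_mixture(q))
    return js.mean()
```

The estimate is bounded above by `log(2)`, and it is noisy for small `n_samples`, so individual values can come out slightly negative. A larger `n_samples` reduces the variance at the cost of more compute.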