Compute JS loss between gaussian distributions parameterized by mu and log_var

My network outputs 4 tensors with shape BxN, where B is batch size, N is latent dim.

  1. source_mu and source_log_var are mu and log_var of N gaussian distributions
  2. target_mu and target_log_var are mu and log_var of N gaussian distributions also.

I want to measure the JS divergence between those distributions.

Here is my code snippet, can someone help to confirm:

    def compute_js_loss(self, source_mu, source_log_var, target_mu, target_log_var):
        def get_prob(mu, log_var):
            dist = Normal(mu, torch.exp(0.5 * log_var))
            val = dist.sample()
            return dist.log_prob(val).exp()

        def kl_loss(p, q):
            return F.kl_div(p, q, reduction="batchmean", log_target=False)

        source_prob = get_prob(source_mu, source_log_var)
        target_prob = get_prob(target_mu, target_log_var)

        log_mean_prob = (0.5 * (source_prob + target_prob)).log()
        js_loss = 0.5 * (kl_loss(log_mean_prob, source_prob) + kl_loss(log_mean_prob, target_prob))
        return js_loss