My network outputs 4 tensors of shape `B x N`, where `B` is the batch size and `N` is the latent dim.
- `source_mu` and `source_log_var` are the `mu` and `log_var` of `N` Gaussian distributions
- `target_mu` and `target_log_var` are likewise the `mu` and `log_var` of `N` Gaussian distributions
I want to measure the JS divergence between those distributions.
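For reference, the quantity I'm after is the standard Jensen–Shannon divergence,

$$\mathrm{JS}(P \,\|\, Q) = \tfrac{1}{2}\,\mathrm{KL}(P \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(Q \,\|\, M), \qquad M = \tfrac{1}{2}(P + Q),$$

which is what the snippet below tries to implement via `F.kl_div`.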
Here is my code snippet; can someone help confirm it is correct?
```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal

def compute_js_loss(self, source_mu, source_log_var, target_mu, target_log_var):
    def get_prob(mu, log_var):
        # Draw one sample from N(mu, exp(0.5 * log_var)) and evaluate its density there.
        dist = Normal(mu, torch.exp(0.5 * log_var))
        val = dist.sample()
        return dist.log_prob(val).exp()

    def kl_loss(p, q):
        # F.kl_div expects its first argument in log space and computes KL(q || exp(p)).
        return F.kl_div(p, q, reduction="batchmean", log_target=False)

    source_prob = get_prob(source_mu, source_log_var)
    target_prob = get_prob(target_mu, target_log_var)

    # Mixture density m = 0.5 * (p + q), then JS = 0.5 * (KL(P || M) + KL(Q || M)).
    log_mean_prob = (0.5 * (source_prob + target_prob)).log()
    js_loss = 0.5 * (kl_loss(log_mean_prob, source_prob) + kl_loss(log_mean_prob, target_prob))
    return js_loss
```
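For a quick sanity check I run it on random tensors; the shapes below and the `None` stand-in for `self` are just for this smoke test, since the method never touches `self`:

```python
import torch

B, N = 8, 16  # arbitrary batch size and latent dim for this check
source_mu, source_log_var = torch.randn(B, N), torch.randn(B, N)
target_mu, target_log_var = torch.randn(B, N), torch.randn(B, N)

# The method does not use `self`, so pass None to call it standalone.
js = compute_js_loss(None, source_mu, source_log_var, target_mu, target_log_var)
print(js)  # scalar tensor
```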