# Jensen-Shannon Divergence

Hi,

I am trying to implement Jensen-Shannon Divergence (JSD) in PyTorch:

``````
class JSD(torch.nn.Module):
    def forward(self, P, Q):
        kld = KLDivLoss().cuda()
        M = 0.5 * (P + Q)
        return 0.5 * (kld(P, M) + kld(Q, M))
``````

When I run the above code, I am getting the below error:
AssertionError: nn criterions don’t compute the gradient w.r.t. targets - please mark these variables as volatile or not requiring gradients.

I guess KLDivLoss expects the second argument (the target) not to require gradients. But in JSD, the M term also contains a gradient-requiring variable. Is there an easy way of dealing with this? Or should I write my own KLD function?

Thanks.


I think the easiest way would be to write your own KLDiv function just like this one.
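
For reference, here is a minimal sketch of what such a hand-written KL term might look like (my own illustration, not the code from that link; it assumes P, Q, and M are already valid probability distributions and adds a small eps for numerical stability):

``````
import torch

def kl_div(p, q, eps=1e-12):
    # KL(p || q) = sum p * (log p - log q), written with plain tensor ops,
    # so gradients flow through both p and q (no restriction on the target)
    return (p * ((p + eps).log() - (q + eps).log())).sum()

def jsd(P, Q):
    M = 0.5 * (P + Q)
    return 0.5 * (kl_div(P, M) + kl_div(Q, M))
``````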


Hello,

Were you able to implement Jensen-Shannon Divergence? If so, kindly share the code. It would be really helpful.

Thanks

Hope this helps.

``````
class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()

    def forward(self, net_1_logits, net_2_logits):
        net_1_probs = F.softmax(net_1_logits, dim=1)
        net_2_probs = F.softmax(net_2_logits, dim=1)

        m = 0.5 * (net_1_probs + net_1_probs)
        loss = 0.0
        loss += F.kl_div(F.log_softmax(net_1_logits, dim=1), total_m, reduction="batchmean")
        loss += F.kl_div(F.log_softmax(net_2_logits, dim=1), total_m, reduction="batchmean")

        return (0.5 * loss)
``````

So these two variables, net_1_probs and m, are not used?
And where is total_m defined?

I had to modify the example to this:

Note that the function is not designed to handle batches of inputs (matrix arguments), although it might work.

``````
def jenson_shannon_divergence(net_1_logits, net_2_logits):
    import torch.nn.functional as F
    net_1_probs = F.softmax(net_1_logits, dim=0)
    net_2_probs = F.softmax(net_2_logits, dim=0)

    total_m = 0.5 * (net_1_probs + net_1_probs)

    loss = 0.0
    loss += F.kl_div(F.log_softmax(net_1_logits, dim=0), total_m, reduction="batchmean")
    loss += F.kl_div(F.log_softmax(net_2_logits, dim=0), total_m, reduction="batchmean")
    return (0.5 * loss)
``````
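
For example, with two 1-D logit vectors (the shapes and names are just for illustration):

``````
import torch

net_1_logits = torch.randn(10, requires_grad=True)
net_2_logits = torch.randn(10, requires_grad=True)

loss = jenson_shannon_divergence(net_1_logits, net_2_logits)
loss.backward()  # gradients flow back to both sets of logits
``````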

Hey, here are a couple of things I thought worth mentioning here. First, both snippets above are only using:

``````
total_m = 0.5 * (net_1_probs + net_1_probs)
``````

The correct formulation is:

``````
total_m = 0.5 * (net_1_probs + net_2_probs)
``````


Also, based on @jeff-hykin’s and @Aryan_Asadian’s implementations, here’s mine. It is easier for me to use modules that are instances of `nn.Module`, similar to @Aryan_Asadian’s implementation, because I can attach forward/backward hooks.

``````
class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        m = (0.5 * (p + q)).log()
        return 0.5 * (self.kl(p.log(), m) + self.kl(q.log(), m))
``````

Note that I am taking `softmax` before passing `p` and `q` to my `JSD` instance. Also, note that this implementation works with matrices as well, since at the beginning I’m flattening both tensors.
Also, note that I’m passing `log_target=True`, which means `m` should be in log-space. This makes the implementation slightly faster, because we’re computing `m.log()` only once. Hope that this helps.
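
In case it helps, here is roughly how I call it (the shapes and tensor names are just for illustration):

``````
import torch
import torch.nn as nn
import torch.nn.functional as F

jsd = JSD()
logits_a = torch.randn(8, 5, requires_grad=True)  # e.g. (batch, classes)
logits_b = torch.randn(8, 5, requires_grad=True)

# softmax is applied before calling the module, as described above
p = F.softmax(logits_a, dim=-1)
q = F.softmax(logits_b, dim=-1)

loss = jsd(p, q)
loss.backward()  # gradients flow back to both sets of logits
``````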


@Amin_Jun, you did a wonderful job based on the previous answer. However, your implementation is still slightly problematic: it doesn’t guarantee that the JS divergence stays in the range 0 to 1. The KL-divergence function in PyTorch is counterintuitive: KL(a, b) needs to be written as torch.nn.KLDivLoss()(b, a). So the correct one should be:

``````
class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        m = (0.5 * (p + q)).log()
        return 0.5 * (self.kl(m, p.log()) + self.kl(m, q.log()))
``````
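
A quick numerical check of the argument-order convention (the probabilities here are arbitrary, just to illustrate it):

``````
import torch
import torch.nn as nn

a = torch.tensor([0.1, 0.2, 0.7])
b = torch.tensor([0.3, 0.3, 0.4])

# KL(a || b) = sum a * (log a - log b)
manual = (a * (a.log() - b.log())).sum()

# nn.KLDivLoss(input, target) computes KL(target || input),
# where input is expected in log-space
kld = nn.KLDivLoss(reduction='sum')
print(manual, kld(b.log(), a))  # the two values match
``````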

@Renly_Hou You’re absolutely correct! Thanks for catching this. You saved me from many experiments.

Is this implementation correct, assuming that my inputs haven’t already been passed through `softmax`?
According to the torch.log_softmax documentation, it’s recommended to use log_softmax directly instead of log(softmax(…)), so I did that below (a quick numerical check of that equivalence follows the snippet).

``````
class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)).log_softmax(-1), q.view(-1, q.size(-1)).log_softmax(-1)
        m = (0.5 * (p + q))
        return 0.5 * (self.kl(m, p) + self.kl(m, q))
``````
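
For completeness, the log_softmax vs. log(softmax(…)) equivalence can be checked quickly (random values, just as a sanity check):

``````
import torch

x = torch.randn(4, 5)
print(torch.allclose(x.log_softmax(-1), x.softmax(-1).log(), atol=1e-6))  # True
``````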

What are the inputs for the variables p and q? Do I have to pass a single data point from each of the two datasets as tensors, or a batch of data from the two datasets as tensors?