Jensen-Shannon Divergence


I am trying to implement Jensen-Shannon Divergence (JSD) in PyTorch:
import torch
from torch.nn import KLDivLoss

class JSD(torch.nn.Module):
    def forward(self, P, Q):
        kld = KLDivLoss().cuda()
        M = 0.5 * (P + Q)  # mixture distribution
        return 0.5 * (kld(P, M) + kld(Q, M))

When I run the code above, I get the following error:
AssertionError: nn criterions don’t compute the gradient w.r.t. targets - please mark these variables as volatile or not requiring gradients.

I guess KLDivLoss expects the second argument (the target) not to require gradients. But in JSD, the M term also contains a variable that requires gradients. Is there an easy way to deal with this, or should I write my own KLD function?


I think the easiest way would be to write your own KLDiv function just like this one.
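A minimal sketch of such a hand-written KL function, assuming the inputs are probability tensors with the distribution along the last dimension (the names kl_divergence and jensen_shannon are hypothetical). Because it uses only ordinary tensor operations rather than a loss criterion, gradients flow into both arguments, including the mixture M.

```python
import torch


def kl_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """KL(p || q) along the last dim, averaged over all leading (batch) dims.

    Hypothetical helper: plain tensor ops, so gradients reach both p and q.
    """
    # clamp only inside the logs, so entries with p == 0 contribute 0
    log_ratio = p.clamp_min(eps).log() - q.clamp_min(eps).log()
    return (p * log_ratio).sum(dim=-1).mean()


def jensen_shannon(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """JSD(p || q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = 0.5 * (p + q)."""
    m = 0.5 * (p + q)
    return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))


# usage: gradients reach both sets of logits
logits_p = torch.randn(4, 10, requires_grad=True)
logits_q = torch.randn(4, 10, requires_grad=True)
loss = jensen_shannon(logits_p.softmax(dim=-1), logits_q.softmax(dim=-1))
loss.backward()
```

The small eps clamp is an assumption made here to avoid log(0); it only matters for exact zeros in the inputs.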



Were you able to implement Jensen-Shannon Divergence? If so, kindly share the code; it would be really helpful.


Hope this helps.

class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
    def forward(self, net_1_logits, net_2_logits):
        net_1_probs = F.softmax(net_1_logits, dim=1)
        net_2_probs = F.softmax(net_2_logits, dim=1)

        m = 0.5 * (net_1_probs + net_1_probs)
        loss = 0.0
        loss += F.kl_div(F.log_softmax(net_1_logits, dim=1), total_m, reduction="batchmean") 
        loss += F.kl_div(F.log_softmax(net_2_logits, dim=1), total_m, reduction="batchmean") 
        return (0.5 * loss)

So these two variables, net_2_probs and m, are not used?
And where is total_m defined?

I had to modify the example to this:

Note that the function is not designed to handle batches of inputs (matrix arguments), although it might work with them.

import torch.nn.functional as F

def jenson_shannon_divergence(net_1_logits, net_2_logits):
    net_1_probs = F.softmax(net_1_logits, dim=0)
    net_2_probs = F.softmax(net_2_logits, dim=0)
    total_m = 0.5 * (net_1_probs + net_1_probs)
    loss = 0.0
    # "sum" gives the plain KL for a single 1-D distribution; "batchmean" would divide by the number of classes
    loss += F.kl_div(F.log_softmax(net_1_logits, dim=0), total_m, reduction="sum")
    loss += F.kl_div(F.log_softmax(net_2_logits, dim=0), total_m, reduction="sum")
    return (0.5 * loss)

Hey, here are a couple of things I thought worth mentioning. First, both snippets compute the mixture as:

total_m = 0.5 * (net_1_probs + net_1_probs)

The correct formulation is:

total_m = 0.5 * (net_1_probs + net_2_probs)
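For reference, this is the standard definition that total_m corresponds to, with M the mixture of the two distributions P and Q:

```latex
M = \tfrac{1}{2}(P + Q), \qquad
\mathrm{JSD}(P \,\|\, Q) = \tfrac{1}{2}\,\mathrm{KL}(P \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(Q \,\|\, M), \qquad
\mathrm{KL}(P \,\|\, M) = \sum_i P_i \log \frac{P_i}{M_i}
```

With the natural logarithm, 0 ≤ JSD ≤ ln 2; with base-2 logarithms the upper bound is 1.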


Also, based on @jeff-hykin's and @Aryan_Asadian's implementations, here's mine. I find it easier to use modules that are instances of nn.Module, as in @Aryan_Asadian's implementation, because I can attach forward/backward hooks.

class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        m = (0.5 * (p + q)).log()
        return 0.5 * (self.kl(p.log(), m) + self.kl(q.log(), m))

Note that I apply softmax before passing p and q to my JSD instance. This implementation also works with matrices, because at the beginning I reshape both tensors to 2-D, folding all leading dimensions into the batch dimension.
Also note that I pass log_target=True, which means the target m must be in log-space. This makes the implementation slightly faster, because m.log() is computed only once. Hope this helps.
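A small check of what log_target=True changes, as a sketch with random tensors (the variable names are illustrative): the target is given as log-probabilities, and the loss value is the same as passing the plain probabilities with the default log_target=False.

```python
import torch
from torch import nn

torch.manual_seed(0)
log_p = torch.randn(3, 5).log_softmax(dim=-1)  # input: log-probabilities
m = torch.randn(3, 5).softmax(dim=-1)          # target: probabilities

# default: the target is given as probabilities
plain = nn.KLDivLoss(reduction="batchmean")(log_p, m)
# log_target=True: the same target, passed as log-probabilities
in_log_space = nn.KLDivLoss(reduction="batchmean", log_target=True)(log_p, m.log())

# both equal sum(m * (log m - log_p)) / batch_size, i.e. the mean KL(m || p) per row
print(plain.item(), in_log_space.item())
```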


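@Amin_Jun you did a wonderful job building on the previous answer. However, your implementation is still slightly problematic: it doesn't guarantee that the JS divergence stays in the range 0 to 1 (with natural logarithms the upper bound is actually ln 2). The KL-divergence function in PyTorch is counterintuitive: KL(a, b) has to be written as torch.nn.KLDivLoss()(b, a), i.e. the input (first argument, given as log-probabilities) is the second distribution of the KL, and the target is the first. So the correct version should be: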

class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        m = (0.5 * (p + q)).log()
        return 0.5 * (self.kl(m, p.log()) + self.kl(m, q.log()))
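To see the argument order in practice, here is a small sketch comparing nn.KLDivLoss against a hand-written KL(a || b) on random, strictly positive distributions (kl_manual and jsd are hypothetical helper names). It also checks that the corrected JSD is symmetric and stays below ln 2.

```python
import torch
from torch import nn


def kl_manual(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """KL(a || b) along the last dim, averaged over rows; assumes a, b > 0."""
    return (a * (a.log() - b.log())).sum(dim=-1).mean()


kl = nn.KLDivLoss(reduction="batchmean", log_target=True)


def jsd(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """The corrected JSD from the post above, written as a plain function."""
    m = (0.5 * (p + q)).log()
    return 0.5 * (kl(m, p.log()) + kl(m, q.log()))


torch.manual_seed(0)
a = torch.randn(4, 6).softmax(dim=-1)
b = torch.randn(4, 6).softmax(dim=-1)

# KL(a || b): b (in log space) is the input, a (in log space) is the target
print(torch.allclose(kl(b.log(), a.log()), kl_manual(a, b), atol=1e-6))
# swapping the arguments computes KL(b || a) instead
print(torch.allclose(kl(a.log(), b.log()), kl_manual(b, a), atol=1e-6))
# JSD is symmetric, so these agree up to rounding
print(jsd(a, b).item(), jsd(b, a).item())
```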

@Renly_Hou You’re absolutely correct! Thanks for catching this. You saved me from many experiments.

Is this implementation correct if my inputs are not already softmax outputs (i.e., they are raw logits)?
According to this link (torch.log_softmax), it's recommended to use log_softmax directly instead of log(softmax(...)):

class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)).log_softmax(-1), q.view(-1, q.size(-1)).log_softmax(-1)
        m = (0.5 * (p + q))
        return 0.5 * (self.kl(m, p) + self.kl(m, q))
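On the question above: averaging the log-probabilities, as in m = 0.5 * (p + q), gives the log of the elementwise geometric mean of P and Q (which does not even sum to 1), not log((P + Q) / 2). A sketch that keeps the log_softmax inputs but builds log M with torch.logsumexp, assuming the inputs are logits with the class dimension last (JSDFromLogits is a hypothetical name):

```python
import math

import torch
from torch import nn


class JSDFromLogits(nn.Module):
    """Jensen-Shannon divergence between two batches of logits (hypothetical module)."""

    def __init__(self):
        super().__init__()
        self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

    def forward(self, p_logits: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
        # fold leading dims into the batch dim; each row is one distribution
        log_p = p_logits.reshape(-1, p_logits.size(-1)).log_softmax(dim=-1)
        log_q = q_logits.reshape(-1, q_logits.size(-1)).log_softmax(dim=-1)
        # log M = log(0.5 * (P + Q)), computed stably in log space
        log_m = torch.logsumexp(torch.stack([log_p, log_q]), dim=0) - math.log(2.0)
        # kl(log_m, log_p) computes KL(P || M) (target second, see the argument order above)
        return 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))
```

Both inputs are expected with the class dimension last, for example (batch_size, num_classes); a single sample of shape (num_classes,) is treated as a batch of one.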

What are the inputs for p and q? Do I have to pass a single data point from each of the two datasets as tensors, or a batch of data from each dataset?