Jensen Shannon Divergence

Hi,

I am trying to implement Jensen-Shannon divergence (JSD) in PyTorch:

import torch

class JSD(torch.nn.Module):
    def forward(self, P, Q):
        kld = torch.nn.KLDivLoss().cuda()
        M = 0.5 * (P + Q)
        return 0.5 * (kld(P, M) + kld(Q, M))

When I run this code, I get the following error:
AssertionError: nn criterions don’t compute the gradient w.r.t. targets - please mark these variables as volatile or not requiring gradients.

I guess KLDivLoss expects its second argument (the target) not to require gradients. In JSD, however, the M term also depends on variables that require gradients. Is there an easy way to deal with this, or should I write my own KLD function?

Thanks.


I think the easiest way would be to write your own KLDiv function just like this one.
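
For reference, here is a minimal sketch of what such a hand-written KL function could look like. It is not the function linked above. The names kl_divergence and js_divergence and the eps clamp value are illustrative assumptions, and the inputs are assumed to be probabilities whose last dimension sums to 1. Because it uses only plain tensor operations, autograd can propagate gradients to both arguments.

```python
import torch


def kl_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """KL(p || q) for probability tensors (last dimension sums to 1).

    eps is an assumed clamp value that avoids log(0); it is not from the thread.
    """
    p = p.clamp_min(eps)
    q = q.clamp_min(eps)
    # Sum over the class dimension, then average over all remaining dimensions.
    return (p * (p.log() - q.log())).sum(dim=-1).mean()


def js_divergence(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = 0.5 * (p + q)."""
    m = 0.5 * (p + q)
    return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))


# Usage: gradients reach both inputs, including through the mixture m.
torch.manual_seed(0)
logits_a = torch.randn(4, 5, requires_grad=True)
logits_b = torch.randn(4, 5, requires_grad=True)
loss = js_divergence(logits_a.softmax(dim=-1), logits_b.softmax(dim=-1))
loss.backward()
```

With natural logarithms, the result of this sketch should lie between 0 and ln 2.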


Hello,

Were you able to implement Jensen-Shannon divergence? If so, could you kindly share the code? It would be really helpful.

Thanks

Hope this helps.

class JSD(nn.Module):
    
    def __init__(self):
        super(JSD, self).__init__()
    
    def forward(self, net_1_logits, net_2_logits):
        net_1_probs = F.softmax(net_1_logits, dim=1)
        net_2_probs = F.softmax(net_2_logits, dim=1)

        m = 0.5 * (net_1_probs + net_1_probs)
        loss = 0.0
        loss += F.kl_div(F.log_softmax(net_1_logits, dim=1), total_m, reduction="batchmean") 
        loss += F.kl_div(F.log_softmax(net_2_logits, dim=1), total_m, reduction="batchmean") 
     
        return (0.5 * loss)

So these two variables, net_1_probs and m, are not used?
And where is total_m defined?

I had to modify the example to this:

Note that this function is not designed to handle batches of inputs (matrix arguments), although it might work on them.

def jensen_shannon_divergence(net_1_logits, net_2_logits):
    import torch.nn.functional as F
    net_1_probs = F.softmax(net_1_logits, dim=0)
    net_2_probs = F.softmax(net_2_logits, dim=0)
    
    total_m = 0.5 * (net_1_probs + net_1_probs)
    
    loss = 0.0
    loss += F.kl_div(F.log_softmax(net_1_logits, dim=0), total_m, reduction="batchmean") 
    loss += F.kl_div(F.log_softmax(net_2_logits, dim=0), total_m, reduction="batchmean") 
    return (0.5 * loss)

Hey, here are a couple of things I thought were worth mentioning. First, both code snippets above compute the mixture from net_1_probs alone:

total_m = 0.5 * (net_1_probs + net_1_probs)

The correct formulation is:

total_m = 0.5 * (net_1_probs + net_2_probs)
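
To illustrate why the mixture matters, here is a small sketch that evaluates the JSD formula directly from its definition rather than through F.kl_div, so that only the effect of the mixture is visible. The helper names kl and jsd_with_mixture and the example vectors are illustrative assumptions. With the duplicated term, the mixture collapses to the first distribution, one KL term vanishes, and the result is no longer symmetric.

```python
import torch


def kl(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # KL(p || q) for 1-D probability vectors with strictly positive entries.
    return (p * (p.log() - q.log())).sum()


def jsd_with_mixture(p: torch.Tensor, q: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    # JSD formula evaluated with a caller-supplied mixture m.
    return 0.5 * (kl(p, m) + kl(q, m))


p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.1, 0.3, 0.6])

m_wrong = 0.5 * (p + p)  # collapses to p itself
m_right = 0.5 * (p + q)  # the actual mixture of p and q

# With the collapsed mixture, KL(p || m) is zero and only 0.5 * KL(q || p) is left.
wrong_pq = jsd_with_mixture(p, q, m_wrong)
wrong_qp = jsd_with_mixture(q, p, 0.5 * (q + q))

# With the real mixture, swapping the arguments does not change the value.
right_pq = jsd_with_mixture(p, q, m_right)
right_qp = jsd_with_mixture(q, p, m_right)

print(wrong_pq.item(), wrong_qp.item())
print(right_pq.item(), right_qp.item())
```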

Also, based on @jeff-hykin and @Aryan_Asadian implementations, here’s mine. It is easier for me to use modules that are instances of nn.Module similar to @Aryan_Asadian’s implementation, because I can have forward/backward hooks.

class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        m = (0.5 * (p + q)).log()
        return 0.5 * (self.kl(p.log(), m) + self.kl(q.log(), m))

Note that I am taking softmax before passing p and q to my JSD instance. Also, note that this implementation works with matrices as well, since in the beginning I’m flattening both tensors.
Also, note that I’m passing log_target=True, which means the m should be in log-space. This makes the implementation slightly faster, because we’re computing the m.log() only once. Hope that this helps.
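
As a quick illustration of what log_target changes, the sketch below compares the two ways of passing the target to nn.KLDivLoss. The tensor names are illustrative assumptions, and it assumes a PyTorch version that supports the log_target argument. Only the representation of the target differs; the computed value should be the same.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
log_input = torch.randn(4, 5).log_softmax(dim=-1)  # first argument: always log-probabilities
target = torch.randn(4, 5).softmax(dim=-1)         # second argument: a probability distribution

kl_prob_target = nn.KLDivLoss(reduction="batchmean", log_target=False)
kl_log_target = nn.KLDivLoss(reduction="batchmean", log_target=True)

value_prob = kl_prob_target(log_input, target)       # target passed as probabilities
value_log = kl_log_target(log_input, target.log())   # target passed as log-probabilities

print(value_prob.item(), value_log.item())
```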


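@Amin_Jun, you did a wonderful job building on the previous answers. However, your implementation still has a problem: the arguments to the KL terms are in the wrong order, so the result is not the JS divergence and is not guaranteed to stay within its bounds (0 to ln 2 with natural logarithms, or 0 to 1 with base-2 logarithms). The KL-divergence function in PyTorch is counterintuitive: KLDivLoss()(input, target) computes KL(target || input), so KL(a, b) needs to be written as torch.nn.KLDivLoss()(b, a), with b given as log-probabilities. The correct version should be: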

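import torch
import torch.nn as nn

class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        # log_target=True: both arguments are expected in log-space.
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        # Flatten to (N, num_classes); p and q must already be probabilities.
        p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
        m = (0.5 * (p + q)).log()
        # KLDivLoss(input, target) computes KL(target || input).
        return 0.5 * (self.kl(m, p.log()) + self.kl(m, q.log()))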
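
The argument order can be checked numerically. The sketch below is only an illustration: kl_manual and jsd are assumed helper names, and the random inputs are arbitrary. It compares F.kl_div against KL written out from the definition, then checks the usual JSD properties (symmetry, zero for identical inputs, upper bound of ln 2 with natural logarithms).

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = torch.randn(3, 6).softmax(dim=-1)
q = torch.randn(3, 6).softmax(dim=-1)


def kl_manual(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # KL(a || b), summed over classes and averaged over the batch.
    return (a * (a.log() - b.log())).sum(dim=-1).mean()


# F.kl_div(input, target): input holds log-probabilities, target holds probabilities.
# Passing (log q, p) therefore yields KL(p || q), not KL(q || p).
torch_value = F.kl_div(q.log(), p, reduction="batchmean")


def jsd(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Same argument order as the corrected module: the log-mixture goes first.
    log_m = (0.5 * (p + q)).log()
    kl_pm = F.kl_div(log_m, p.log(), reduction="batchmean", log_target=True)
    kl_qm = F.kl_div(log_m, q.log(), reduction="batchmean", log_target=True)
    return 0.5 * (kl_pm + kl_qm)


print(torch_value.item(), kl_manual(p, q).item(), kl_manual(q, p).item())
print(jsd(p, q).item(), jsd(q, p).item(), math.log(2.0))
```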

@Renly_Hou You’re absolutely correct! Thanks for catching this. You saved me from many experiments.

Is this implementation correct if my inputs are raw logits rather than softmax outputs?
According to this link: torch.log_softmax, it’s recommended to use log_softmax directly instead of log(softmax…)

class JSD(nn.Module):
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        p, q = p.view(-1, p.size(-1)).log_softmax(-1), q.view(-1, q.size(-1)).log_softmax(-1)
        m = (0.5 * (p + q))
        return 0.5 * (self.kl(m, p) + self.kl(m, q))
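
One thing to watch in the snippet above: once p and q are log-probabilities, 0.5 * (p + q) averages the logs, which is not the same as the log of the mixture 0.5 * (exp(p) + exp(q)). A possible way to stay in log-space, sketched here under the assumption that the inputs are raw logits, is to compute the log-mixture with torch.logsumexp. The class name JSDFromLogits is hypothetical and not from the thread.

```python
import math

import torch
import torch.nn as nn


class JSDFromLogits(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

    def forward(self, p_logits: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
        # Flatten to (N, num_classes) and normalise in log-space.
        log_p = p_logits.reshape(-1, p_logits.size(-1)).log_softmax(dim=-1)
        log_q = q_logits.reshape(-1, q_logits.size(-1)).log_softmax(dim=-1)
        # log(0.5 * (p + q)) = logsumexp([log p, log q]) - log 2
        log_m = torch.logsumexp(torch.stack([log_p, log_q]), dim=0) - math.log(2.0)
        # KLDivLoss(input, target) computes KL(target || input).
        return 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))


# Usage with arbitrary logits of shape (batch, seq, num_classes).
torch.manual_seed(0)
logits_a = torch.randn(2, 3, 7)
logits_b = torch.randn(2, 3, 7)
value = JSDFromLogits()(logits_a, logits_b)
print(value.item())
```

The value should agree with the probability-space version applied to the softmax of the same logits, up to floating-point error.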