Calculate Jensen-Shannon divergence

I would like to calculate the JSD across N probability distributions. Is this a correct way to implement JSD?
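For reference, the definition I am trying to implement, with M as the uniform mixture of the N distributions, is:

JSD(P_1, ..., P_N) = (1/N) * Σ_i KL(P_i || M),   where   M = (1/N) * Σ_i P_i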

import torch
import torch.nn.functional as F

def jsd_loss(logits1, logits2, logits3, logits4):
    softmax1 = torch.softmax(logits1 + 1e-10, 1)
    softmax2 = torch.softmax(logits2 + 1e-10, 1)
    softmax3 = torch.softmax(logits3 + 1e-10, 1)
    softmax4 = torch.softmax(logits4 + 1e-10, 1)

    M = 0.25 * (softmax1 + softmax2 + softmax3 + softmax4)

    return 0.25 * (F.kl_div(M.log(), softmax1) +
                   F.kl_div(M.log(), softmax2) +
                   F.kl_div(M.log(), softmax3) +
                   F.kl_div(M.log(), softmax4))

I have added an epsilon to the logits to make sure the result does not go to inf.
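To handle an arbitrary number of distributions, I also sketched a generalized version below (a minimal sketch, not tested much: the name jsd_loss_n and the use of reduction='batchmean' are my own choices, and the epsilon is added to the mixture before the log instead of to the logits). Please correct me if this is also wrong:

import torch
import torch.nn.functional as F

def jsd_loss_n(*logits_list, eps=1e-10):
    # Convert each set of logits to probabilities; softmax itself cannot
    # produce inf, so the epsilon is only needed before taking the log below.
    probs = [torch.softmax(logits, dim=1) for logits in logits_list]

    # M is the uniform mixture of the N distributions.
    M = torch.stack(probs, dim=0).mean(dim=0)
    log_M = (M + eps).log()

    # JSD = (1/N) * sum_i KL(P_i || M).
    # F.kl_div(input, target) computes KL(target || Q) with input = log Q,
    # so passing log_M and P_i gives KL(P_i || M).
    # reduction='batchmean' sums over classes and averages over the batch.
    return sum(F.kl_div(log_M, p, reduction='batchmean') for p in probs) / len(probs)

Usage would then be jsd = jsd_loss_n(logits1, logits2, logits3, logits4).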
