I would like to compute the Jensen-Shannon divergence (JSD) across N probability distributions. Is this a correct way to implement JSD?
import torch
import torch.nn.functional as F

def jsd_loss(logits1, logits2, logits3, logits4):
    softmax1 = torch.softmax(logits1 + 1e-10, 1)
    softmax2 = torch.softmax(logits2 + 1e-10, 1)
    softmax3 = torch.softmax(logits3 + 1e-10, 1)
    softmax4 = torch.softmax(logits4 + 1e-10, 1)
    # M is the mixture (average) of the four distributions
    M = 0.25 * (softmax1 + softmax2 + softmax3 + softmax4)
    # F.kl_div expects log-probabilities as its first argument
    return 0.25 * (F.kl_div(M.log(), softmax1) +
                   F.kl_div(M.log(), softmax2) +
                   F.kl_div(M.log(), softmax3) +
                   F.kl_div(M.log(), softmax4))
I have added an epsilon to the logits to make sure the log does not go to infinity.
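For comparison, here is a minimal sketch of how the same idea could be generalized to N sets of logits. This is an assumption on my part, not the code above: it uses `reduction='batchmean'` (which the PyTorch docs note is the normalization that matches the mathematical definition of KL; the default `'mean'` divides by the number of elements instead), and it clamps the mixture M before taking the log rather than shifting the logits (adding a constant epsilon to every logit has no effect, since softmax is shift-invariant).

```python
import torch
import torch.nn.functional as F

def jsd_loss_n(*logits_list):
    """Jensen-Shannon divergence across N categorical distributions.

    A sketch under the assumptions stated above: 'batchmean' reduction,
    and clamping M before the log instead of adding epsilon to logits.
    """
    probs = [torch.softmax(l, dim=1) for l in logits_list]
    m = torch.stack(probs).mean(dim=0)       # mixture distribution M
    log_m = m.clamp_min(1e-10).log()         # guard against log(0)
    # JSD = (1/N) * sum_i KL(P_i || M); F.kl_div(log_m, p) computes KL(p || m)
    return sum(F.kl_div(log_m, p, reduction='batchmean') for p in probs) / len(probs)
```

A sanity check on this sketch: identical inputs should give (near) zero divergence, and distinct inputs a positive value.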