Is this implementation correct, assuming that my inputs have not already been passed through softmax?
According to the torch.log_softmax documentation, it is recommended to use log_softmax directly instead of computing log(softmax(...)).
class JSD(nn.Module):
    """Jensen-Shannon divergence between two batches of unnormalised scores.

    Both inputs are flattened to ``(-1, last_dim)`` and normalised with
    ``log_softmax`` over the last dimension, so callers must pass scores that
    have NOT already been through softmax.

    The divergence is ``0.5 * KL(P || M) + 0.5 * KL(Q || M)`` with the mixture
    ``M = 0.5 * (P + Q)``. Each KL term uses ``reduction='batchmean'``, i.e. it
    is summed over all elements and divided by the number of flattened rows.
    """

    # ln(2), spelled out as a literal so the class needs no extra import.
    _LOG_2 = 0.6931471805599453

    def __init__(self):
        super().__init__()
        # log_target=True: both arguments passed to self.kl are log-probabilities.
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Return the scalar Jensen-Shannon divergence between ``p`` and ``q``.

        Args:
            p: Unnormalised scores; softmax is taken over the last dimension.
            q: Unnormalised scores with the same shape as ``p``.

        Returns:
            A zero-dimensional tensor holding the batch-mean divergence.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        log_p = p.reshape(-1, p.size(-1)).log_softmax(-1)
        log_q = q.reshape(-1, q.size(-1)).log_softmax(-1)

        # The mixture must be averaged in probability space:
        #   log M = log(0.5 * (P + Q)) = logaddexp(log P, log Q) - ln 2.
        # Averaging the log-probabilities instead would give the log of the
        # (unnormalised) geometric mean, which is not the JSD mixture.
        log_m = torch.logaddexp(log_p, log_q) - self._LOG_2

        # KLDivLoss(input, target) computes KL(target || input), so the
        # mixture goes first and the distribution being compared second.
        return 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))