Jensen-Shannon Divergence

Hi,

I am trying to implement the Jensen-Shannon Divergence (JSD) in PyTorch:
class JSD(torch.nn.Module):
    def forward(self, P, Q):
        kld = torch.nn.KLDivLoss().cuda()
        M = 0.5 * (P + Q)
        return 0.5 * (kld(P, M) + kld(Q, M))
When I run the above code, I get the following error:
AssertionError: nn criterions don’t compute the gradient w.r.t. targets - please mark these variables as volatile or not requiring gradients.

I guess KLDivLoss expects its second argument (the target) to not require gradients. But in JSD, the M term also contains a variable that requires gradients. Is there an easy way of dealing with this, or should I write my own KLD function?
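
In case it helps, this is roughly what I had in mind if I end up writing the KL myself: a minimal sketch that computes KL(P || M) and KL(Q || M) directly from the probabilities, so no target needs to have gradients disabled. It assumes P and Q are already probability distributions over the last dimension; the class name JSDManual, the eps constant, and the final mean reduction are just my own choices for illustration.

import torch

class JSDManual(torch.nn.Module):
    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps  # avoids log(0) for zero-probability entries

    def kl(self, p, q):
        # KL(p || q) = sum_i p_i * (log p_i - log q_i), summed over the last dimension
        return (p * (torch.log(p + self.eps) - torch.log(q + self.eps))).sum(dim=-1)

    def forward(self, P, Q):
        M = 0.5 * (P + Q)
        # JSD(P || Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M)
        return 0.5 * (self.kl(P, M) + self.kl(Q, M)).mean()

If P and Q come from logits, they would be passed through torch.softmax(logits, dim=-1) first.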

Thanks.

(I created this post by mistake; it has already been answered. But I don't know how to remove this one.)


You should add

super().__init__()

as the first line of __init__.
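
For example (the KLDivLoss attribute here is just illustrative, not from the original code):

class JSD(torch.nn.Module):
    def __init__(self):
        super().__init__()  # must run before any submodules are assigned
        self.kld = torch.nn.KLDivLoss()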