Jensen-Shannon Divergence


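I am trying to implement the Jensen-Shannon divergence (JSD) in PyTorch:

import torch
from torch.nn import KLDivLoss

class JSD(torch.nn.Module):

    def forward(self, P, Q):
        # note: KLDivLoss expects its first argument (input) as log-probabilities
        kld = KLDivLoss().cuda()
        M = 0.5 * (P + Q)
        return 0.5 * (kld(P, M) + kld(Q, M))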
When I run the code above, I get the following error:
AssertionError: nn criterions don’t compute the gradient w.r.t. targets - please mark these variables as volatile or not requiring gradients.

I guess KLDivLoss expects its second argument (the target) not to require gradients. But in JSD, the M term is computed from P and Q, so it also requires gradients. Is there an easy way to deal with this, or should I write my own KLD function?
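Writing the KL term by hand avoids the target restriction entirely, because no argument is treated as a fixed target. The sketch below assumes P and Q are probability tensors of shape (batch, classes) whose rows sum to 1 (e.g. softmax outputs); the helper names kl_from_probs and jsd and the eps clamp value are hypothetical choices, not part of any library.

```python
import torch

def kl_from_probs(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Row-wise KL(p || q) for probability tensors of shape (batch, classes)."""
    # Clamp away exact zeros so log() stays finite; the error this adds is negligible.
    p = p.clamp_min(eps)
    q = q.clamp_min(eps)
    return (p * (p.log() - q.log())).sum(dim=-1)

def jsd(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Row-wise Jensen-Shannon divergence (natural log, so values lie in [0, ln 2])."""
    m = 0.5 * (p + q)
    return 0.5 * (kl_from_probs(p, m) + kl_from_probs(q, m))

# Usage: gradients flow back to both inputs through M.
torch.manual_seed(0)
a = torch.randn(4, 5, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)
loss = jsd(a.softmax(dim=-1), b.softmax(dim=-1)).mean()
loss.backward()
print(a.grad.shape, b.grad.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```

Because everything is ordinary tensor arithmetic, autograd differentiates through both P and Q (and through M) without any special handling.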


(I created this post by mistake; it has already been answered, but I don't know how to delete this one.)


You should add


as the first line of __init__.
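For comparison, a module version that builds the loss once in __init__ (after the super().__init__() call that any nn.Module subclass defining __init__ needs) might look like the sketch below. It assumes a recent PyTorch (log_target requires 1.6 or later), where the built-in loss also backpropagates into its target, and that both inputs are log-probabilities such as log_softmax outputs; the parameter names log_p and log_q are illustrative. Note the argument order: KLDivLoss treats its first argument as the approximating distribution in log space and its second as the target, so KL(P || M) is written kl(log M, log P), the reverse of the call in the question.

```python
import math

import torch
import torch.nn as nn

class JSD(nn.Module):
    """Jensen-Shannon divergence between two batches of log-probabilities."""

    def __init__(self):
        # Any nn.Module subclass that defines __init__ must call this first.
        super().__init__()
        # log_target=True: both arguments to self.kl are log-probabilities.
        # batchmean: sum over all elements divided by the batch size.
        self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

    def forward(self, log_p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
        # log M = log(0.5 * (P + Q)), computed in log space for stability.
        log_m = torch.logaddexp(log_p, log_q) - math.log(2.0)
        # KLDivLoss(input, target) computes KL(target || exp(input)),
        # so KL(P || M) is self.kl(log_m, log_p).
        return 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))

# Usage: pass log_softmax outputs. KLDivLoss has no parameters,
# so calling .cuda() on it is unnecessary.
torch.manual_seed(0)
a = torch.randn(4, 5, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)
loss = JSD()(a.log_softmax(dim=-1), b.log_softmax(dim=-1))
loss.backward()
```

With reduction="batchmean", the result is the JSD averaged over the rows of the batch.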