to see torch.nn.kl_div, but it is just a class that wraps a function called torch.kl_div.
I can’t find that original function. I wrote my own version, and it gives different results from the PyTorch built-in, so I’m wondering what the built-in looks like.
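For reference, the PyTorch documentation for `torch.nn.functional.kl_div` (with the default `log_target=False`) gives the pointwise loss below. Here $x$ is the `input`, which is expected to hold log-probabilities, and $y$ is the `target`, which holds probabilities. This is the documented formula, not the C++ source:

$$
\ell(x, y) = y \cdot (\log y - x)
$$

With `reduction='mean'` (the default), the summed $\ell$ is divided by the total number of elements. With `reduction='batchmean'`, it is divided by the batch size $N$, which gives the usual KL divergence:

$$
\frac{1}{N} \sum_{n=1}^{N} \sum_{i} y_{n,i} \left( \log y_{n,i} - \log p_{n,i} \right) \quad \text{when } x = \log p
$$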
For completeness, here’s my code:
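    import torch as t

    def my_kl(predicted, target):
        # cross-entropy term: -sum(target * log(predicted)) per row, averaged over the batch
        cross_entropy = -(target * t.log(predicted.clamp_min(1e-7))).sum(dim=1).mean()
        # entropy term: -sum(target * log(target)) per row, averaged over the batch
        entropy = -(target.clamp(min=1e-7) * t.log(target.clamp(min=1e-7))).sum(dim=1).mean()
        # KL(target || predicted) = cross-entropy - entropy
        return cross_entropy - entropy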
I believe this gives the correct results, LOL.
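To narrow down where a mismatch could come from, here is a minimal pure-Python sketch (no PyTorch needed) that puts the pointwise formula documented for `kl_div` next to a port of `my_kl`. The helper names `documented_kl_div` and `my_kl_reference` are hypothetical. The sketch mirrors the documentation, not PyTorch's actual kernel. It assumes `input` is passed as log-probabilities. Under that assumption, `my_kl` lines up with `reduction="batchmean"`, while the default `"mean"` divides by the number of elements rather than by the batch size:

```python
import math


def documented_kl_div(log_input, target, reduction="mean"):
    """Hypothetical sketch of the documented kl_div formula (log_target=False):
    loss = target * (log(target) - input), with input given as log-probabilities.
    log_input and target are lists of rows (batch x classes); zero targets add 0."""
    total = 0.0
    for x_row, y_row in zip(log_input, target):
        for x, y in zip(x_row, y_row):
            if y > 0:
                total += y * (math.log(y) - x)
    batch = len(target)
    numel = sum(len(row) for row in target)
    if reduction == "sum":
        return total
    if reduction == "batchmean":
        return total / batch  # divide by batch size
    if reduction == "mean":
        return total / numel  # divide by total number of elements
    raise ValueError(f"unknown reduction: {reduction}")


def my_kl_reference(predicted, target, eps=1e-7):
    """Hypothetical pure-Python port of my_kl: cross-entropy minus entropy,
    each summed per row and averaged over rows."""
    cross_entropy = 0.0
    entropy = 0.0
    for p_row, y_row in zip(predicted, target):
        cross_entropy -= sum(y * math.log(max(p, eps)) for p, y in zip(p_row, y_row))
        entropy -= sum(max(y, eps) * math.log(max(y, eps)) for y in y_row)
    n = len(target)
    return cross_entropy / n - entropy / n


predicted = [[0.7, 0.2, 0.1], [0.3, 0.4, 0.3]]
target = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]
log_predicted = [[math.log(p) for p in row] for row in predicted]

mine = my_kl_reference(predicted, target)
batchmean = documented_kl_div(log_predicted, target, reduction="batchmean")
mean = documented_kl_div(log_predicted, target, reduction="mean")
print(abs(mine - batchmean) < 1e-9)  # True
print(abs(mean * 3 - batchmean) < 1e-9)  # True: "mean" divides by 6 elements, not 2 rows
```

If `predicted` is passed straight in as `input`, the formula subtracts probabilities instead of log-probabilities, so the results will generally differ as well.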