import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self, channel, classes=10):
        super(Network, self).__init__()
        self.fc = nn.Linear(channel, classes)

    def forward(self, x):
        x = self.fc(x)
        return x

channel = 32                      # example feature size, any value works
input = torch.randn(16, channel)  # batch of 16 feature vectors
net = Network(channel)
out = net(input)
# out.shape == [16, 10]

Q = out
P = out
print((P * (P / Q).log()).sum())          # prints 0
print(F.kl_div(Q, P, None, None, 'sum'))  # prints a negative value
I heard that (P * (P / Q).log()).sum() is equal to F.kl_div(Q.log(), P, None, None, 'sum').

With the test code above, (P * (P / Q).log()).sum() returns 0 when P and Q are the same, but F.kl_div returns a negative value.

What's the correct way to compute the KL divergence?
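For reference, this is a minimal check of the equivalence I heard about, under the assumption that both P and Q should be proper probability distributions (I apply softmax to random logits here; the 16x10 shape just mirrors my output above):

import torch
import torch.nn.functional as F

# Turn raw logits into probability distributions along the class dimension.
logits = torch.randn(16, 10)
P = F.softmax(logits, dim=1)
Q = F.softmax(logits, dim=1)

# Manual KL(P || Q).
manual = (P * (P / Q).log()).sum()

# F.kl_div with log-probabilities as the first argument.
builtin = F.kl_div(Q.log(), P, None, None, 'sum')

print(manual)   # tensor(0.)
print(builtin)  # tensor(0.) as well, since P and Q are identical here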