Hello,
I’m trying to implement the Hellinger distance as a loss function. The semantics should be like `CrossEntropyLoss`: the first input is the logits and the second is the expected target indices.
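In other words, it should be callable the same way, e.g. (shapes here are just made up for illustration):

```python
import torch

logits = torch.randn(8, 4)           # raw scores, shape (batch, num_classes)
targets = torch.randint(0, 4, (8,))  # class indices, shape (batch,)
loss = hellinger_distance(logits, targets)  # scalar, like F.cross_entropy(logits, targets)
```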
My implementation looks like this:
```python
import torch

def hellinger_distance(P, target):
    Pp = torch.softmax(P, dim=1)
    # The Hellinger distance between probability distributions P & Q is
    # H(P, Q) = (1/sqrt(2)) * sqrt(sum_i (sqrt(p_i) - sqrt(q_i))^2)
    # If Q is a one-hot vector, this formula becomes much simpler:
    # Let t be the target index that carries the one in the Q distribution;
    # every other index is zero.
    # For an index with q_i = 0, the summand (sqrt(p_i) - sqrt(0))^2 becomes just p_i.
    # For the index t with q_t = 1, the summand expands to (sqrt(p_t) - sqrt(1))^2 = p_t - 2*sqrt(p_t) + 1.
    # So the whole sum can be computed as a sum over the entire P, including p_t,
    # and then p_t is picked out of P and sum(P) - 2*sqrt(p_t) + 1 is the final inner sum.
    # What remains is just another square root and a multiplication with the constant 1/sqrt(2).
    p_t = Pp.gather(1, target.view(P.shape[0], 1), sparse_grad=False).clamp(0.0001)
    if Pp.requires_grad:
        inner_sum = Pp.sum(dim=1)
        inner_sum -= p_t.view(P.shape[0]).sqrt()*2 - 1
    else:
        # if we don't need the gradients, we can simply assume the sum of all probs after softmax is 1
        inner_sum = 2 - p_t.view(P.shape[0]).sqrt()*2  # i.e. 1 - 2*sqrt(p_t) + 1
    hellinger = torch.sqrt(inner_sum.clamp(0.0001))*0.7071067811865475  # 1/sqrt(2) = 0.7071067811865475
    # collect the per-sample distances into one batch mean
    collected = torch.mean(hellinger)
    return collected

x = torch.tensor([[0.0, 100.0, 0.0, 0.0],
                  [100.0, 0.0, 0.0, 0.0],
                  [0.0, 0.0, 100.0, 0.0],
                  [0.0, 0.0, 0.0, 100.0]], dtype=torch.float32, requires_grad=True)
y = torch.tensor([0, 0, 0, 0]).long()
l = hellinger_distance(x, y)
l.backward()
print(l)
print(x)
print(x.grad)
```
which gives the output:
```
tensor(0.7480, grad_fn=<MeanBackward0>)
tensor([[  0., 100.,   0.,   0.],
        [100.,   0.,   0.,   0.],
        [  0.,   0., 100.,   0.],
        [  0.,   0.,   0., 100.]], requires_grad=True)
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
```
As you can see, the gradients are all zero. I traced the gradients through the calculation, and everything seems fine until the gather operation. I also tried to detach the `p_t` tensor to break the backward calculation, and it did not change the other values at all. So it seems that the gather operation does not propagate the gradients it receives via `p_t` back to the `Pp` tensor.
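For reference, the tracing looked roughly like this (tensor hooks on the intermediates inside the function; this is an illustrative sketch, not my exact debugging code):

```python
# placed inside hellinger_distance, right after Pp and p_t are computed:
Pp.register_hook(lambda g: print("grad reaching Pp:\n", g))
p_t.register_hook(lambda g: print("grad reaching p_t:\n", g))
```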
I’m using version `2.3.0a0+40ec155e58.nv24.03`, but I also quickly checked with `2.1.0+cpu` and got the same result.
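For completeness, here is a small check (independent of the gradient issue) that the one-hot simplification from the comments matches the textbook formula; the sizes and random inputs are only for illustration:

```python
import torch

torch.manual_seed(0)
num_classes = 5
logits = torch.randn(3, num_classes)
target = torch.randint(0, num_classes, (3,))

P = torch.softmax(logits, dim=1)
Q = torch.nn.functional.one_hot(target, num_classes).float()

# textbook form: (1/sqrt(2)) * sqrt(sum_i (sqrt(p_i) - sqrt(q_i))^2) per row
direct = torch.sqrt(((P.sqrt() - Q.sqrt()) ** 2).sum(dim=1)) / (2 ** 0.5)

# simplified one-hot form: sqrt(sum(P) - 2*sqrt(p_t) + 1) / sqrt(2)
p_t = P.gather(1, target.view(-1, 1)).squeeze(1)
simplified = torch.sqrt(P.sum(dim=1) - 2 * p_t.sqrt() + 1) / (2 ** 0.5)

print(torch.allclose(direct, simplified))  # expect: True
```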