Torch.gather does not propagate gradients?

Hello,

I’m trying to implement the Hellinger distance as a loss function. The semantics should match CrossEntropyLoss: the first input is the logits and the second is the tensor of expected target indices.
My implementation looks like this:

import torch

def hellinger_distance(P, target):
    Pp = torch.softmax(P, dim=1)
    # The Hellinger distance between probability distributions P and Q is
    # H(P, Q) = (1/sqrt(2)) * sqrt(sum_i (sqrt(p_i) - sqrt(q_i))^2)
    # If Q is a one-hot vector, this formula becomes much simpler:
    # let t be the target index, so q_t = 1 and q_i = 0 for every other index.
    # For an index with q_i = 0, the summand (sqrt(p_i) - sqrt(0))^2 is just p_i.
    # For the index t with q_t = 1, the summand expands to (sqrt(p_t) - sqrt(1))^2 = p_t - 2*sqrt(p_t) + 1.
    # So the whole calculation reduces to a sum over all of P, including p_t,
    # plus the correction -2*sqrt(p_t) + 1, i.e. the inner sum is sum(P) - 2*sqrt(p_t) + 1.
    # What remains is the outer square root and multiplication by the constant 1/sqrt(2).

    p_t = Pp.gather(1, target.view(P.shape[0], 1), sparse_grad=False).clamp(0.0001)
    if Pp.requires_grad:
        inner_sum = Pp.sum(dim=1)
        inner_sum -= p_t.view(P.shape[0]).sqrt()*2 - 1
    else:
        # if we don't need the gradients, we can simply assume the sum of all probs after softmax is 1
        inner_sum = 2 - p_t.view(P.shape[0]).sqrt()*2 # 1 - p_t.view(inner_sum.shape).sqrt()*2 + 1

    hellinger = torch.sqrt(inner_sum.clamp(0.0001))*0.7071067811865475 # 1/sqrt(2) = 0.7071067811865475
    # collect distances as batch
    collected = torch.mean(hellinger)
    return collected

x=torch.tensor([[0.0, 100.0, 0.0, 0.0], [100.0, 0.0, 0.0, 0.0], [0.0, 0.0, 100.0, 0.0],[0.0, 0.0, 0.0, 100.0]], dtype=torch.float32, requires_grad=True)
y=torch.tensor([0, 0, 0, 0]).long()
l=hellinger_distance(x, y)
l.backward()
print(l)
print(x)
print(x.grad)

which gives the output:

tensor(0.7480, grad_fn=<MeanBackward0>)
tensor([[  0., 100.,   0.,   0.],
        [100.,   0.,   0.,   0.],
        [  0.,   0., 100.,   0.],
        [  0.,   0.,   0., 100.]], requires_grad=True)
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

As you can see, the gradients are all zero. I traced the gradients through the calculation and everything seems fine until the gather operation. I also tried detaching the p_t tensor to break the backward pass there, and it did not change the other values at all.
So it seems that the gather operation does not propagate the gradients it receives via p_t to the Pp tensor.
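One way to locate where the gradient actually stops is to call retain_grad() on the intermediate tensors. The sketch below re-implements the forward pass from the question with non-in-place ops (the names p_raw and inner are hypothetical, introduced only for this trace) and prints the gradient arriving at the gather output, the inner sum, the softmax output and the logits.

```python
import torch

x = torch.tensor([[0.0, 100.0, 0.0, 0.0], [100.0, 0.0, 0.0, 0.0],
                  [0.0, 0.0, 100.0, 0.0], [0.0, 0.0, 0.0, 100.0]], requires_grad=True)
y = torch.tensor([0, 0, 0, 0])

Pp = torch.softmax(x, dim=1)
Pp.retain_grad()
p_raw = Pp.gather(1, y.view(-1, 1))  # gather output, before the clamp
p_raw.retain_grad()
p_t = p_raw.clamp(0.0001)
inner = Pp.sum(dim=1) - 2 * p_t.view(-1).sqrt() + 1  # same value as the in-place version
inner.retain_grad()
loss = (torch.sqrt(inner.clamp(0.0001)) * 0.7071067811865475).mean()
loss.backward()

print(p_raw.detach().view(-1))  # raw p_t per row: tiny for rows 0, 2, 3 and 1.0 for row 1
print(p_raw.grad.view(-1))      # gradient reaching the gather output
print(inner.grad)               # gradient reaching the inner sum (row 1 is clamped up from 0)
print(Pp.grad)                  # constant within each row, coming only from Pp.sum
print(x.grad)                   # softmax maps a per-row constant gradient to (almost) zero
```

If the trace behaves as the comments describe, the gradient is already zero at the gather output, which points at the clamps rather than at gather's backward.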

I’m using version 2.3.0a0+40ec155e58.nv24.03, but I also quickly checked with 2.1.0+cpu and got the same result.
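As a sanity check of the simplification in the comments of hellinger_distance, here is a minimal sketch (the functions hellinger_full and hellinger_simplified are hypothetical) that compares the full one-hot Hellinger formula with sqrt(sum(P) - 2*sqrt(p_t) + 1)/sqrt(2). It uses moderate random logits and no clamps, and also compares the gradients with respect to the logits.

```python
import math
import torch
import torch.nn.functional as F

def hellinger_full(logits, target):
    # Reference: (1/sqrt(2)) * sqrt(sum_i (sqrt(p_i) - sqrt(q_i))^2) with a one-hot Q
    p = torch.softmax(logits, dim=1)
    q = F.one_hot(target, num_classes=logits.shape[1]).to(p.dtype)
    per_row = torch.sqrt(((p.sqrt() - q.sqrt()) ** 2).sum(dim=1)) / math.sqrt(2)
    return per_row.mean()

def hellinger_simplified(logits, target):
    # Simplified form: (1/sqrt(2)) * sqrt(sum(P) - 2*sqrt(p_t) + 1), without clamps
    p = torch.softmax(logits, dim=1)
    p_t = p.gather(1, target.view(-1, 1)).view(-1)
    per_row = torch.sqrt(p.sum(dim=1) - 2 * p_t.sqrt() + 1) / math.sqrt(2)
    return per_row.mean()

torch.manual_seed(0)
logits = torch.randn(8, 5)          # moderate logits, so the softmax is not saturated
target = torch.randint(0, 5, (8,))

values, grads = [], []
for fn in (hellinger_full, hellinger_simplified):
    x = logits.clone().requires_grad_(True)
    loss = fn(x, target)
    loss.backward()
    values.append(loss.item())
    grads.append(x.grad)

print(values)                                         # the two losses should agree
print(torch.allclose(grads[0], grads[1], atol=1e-5))  # and so should their gradients
```

With unsaturated logits and no clamping, the gradient through gather is nonzero and matches the reference formula.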

Hi @ozppupbg,

I don’t believe torch.gather has a gradient function; what would the gradient of selecting a value from a tensor even be?

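Not with respect to the target/index tensor, but with respect to the input tensor (Pp in this case): each gathered value's gradient should flow back to the position in Pp it was read from.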
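Concretely, gather's backward pass scatters the upstream gradient back into the input positions it read from, accumulating where an index repeats. A small check with made-up values compares inp.grad against Tensor.scatter_add applied to a zero tensor:

```python
import torch

inp = torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)
index = torch.tensor([[2, 2], [0, 1]])        # row 0 reads column 2 twice
out = inp.gather(1, index)                    # shape (2, 2)
upstream = torch.tensor([[3.0, 1.0], [-1.0, 2.0]])
out.backward(upstream)

# Same result built by hand: add each upstream value at the position it was gathered from
expected = torch.zeros_like(inp).scatter_add(1, index, upstream)
print(inp.grad)  # [[0, 0, 4], [-1, 2, 0]]: column 2 of row 0 accumulates 3 + 1
```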

So, after looking at the docs for torch.gather and following their example, torch.gather does indeed have a grad_fn:

import torch

t = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
loss = torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
loss.mean().backward()
t.grad
# returns tensor([[0.5000, 0.0000],
#                 [0.2500, 0.2500]])

The reason you don’t have a grad_fn is that you’ve defined your outputs with dtype torch.long, and integer tensors cannot have gradients by definition. If you run the example above with dtype=torch.long, it will fail.
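The dtype rule itself is easy to confirm: an integer tensor cannot be created with requires_grad=True, but a long index tensor passed to gather is fine as long as the input is floating point. A minimal check (the helper long_tensor_can_require_grad is hypothetical), reusing the values from the example above:

```python
import torch

def long_tensor_can_require_grad():
    # Integer tensors reject requires_grad=True with a RuntimeError
    try:
        torch.tensor([[1, 2], [3, 4]], dtype=torch.long, requires_grad=True)
        return True
    except RuntimeError:
        return False

# Float input with a long index, as in the question, still records a grad_fn
t = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
out = torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))  # index dtype is torch.long

print(long_tensor_can_require_grad())  # False
print(out.grad_fn is not None)         # True
```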

The problem is not the long tensor: my input tensors P and Pp are float tensors, and only the index tensor target is long, as gather requires.
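I figured it out after some experimenting with different inputs: the zero gradients come from an interaction between the clamping and the saturated softmax, not from gather. With logits 100 apart, the softmax saturates, so p_t either drops below the 0.0001 clamp or becomes exactly 1, which drives the inner sum to 0 and into its own clamp. clamp passes no gradient to inputs below its minimum, and the Pp.sum(dim=1) term contributes none either, because the softmax outputs always sum to 1.
As long as the min and max of the input are not more than some value (e.g. 10) apart and I scale the loss by 1000, it works.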
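Both effects can be checked in isolation. The sketch below first shows that clamp() gives zero gradient to an input below min. It then defines a hypothetical clamp-free variant (hellinger_logspace is not from this thread) that computes sqrt(p_t) as exp(0.5 * log_softmax) so p_t never underflows, drops the sum(P) term on the assumption that the softmax sums to 1, and adds a small eps inside the outer square root instead of clamping it.

```python
import math
import torch
import torch.nn.functional as F

# clamp() passes no gradient to inputs below `min`
v = torch.tensor([1e-6, 1e-3], requires_grad=True)
v.clamp(0.0001).sum().backward()
print(v.grad)  # tensor([0., 1.])

def hellinger_logspace(logits, target, eps=1e-8):
    # Hypothetical variant with no hard clamp on the gradient path
    log_p_t = F.log_softmax(logits, dim=1).gather(1, target.view(-1, 1)).view(-1)
    sqrt_p_t = torch.exp(0.5 * log_p_t)  # sqrt(p_t) without materializing a tiny p_t
    inner = 2 - 2 * sqrt_p_t             # assumes sum(P) == 1 after softmax
    return (torch.sqrt(inner + eps) / math.sqrt(2)).mean()

x = torch.tensor([[0.0, 12.0, 0.0, 0.0]], requires_grad=True)
y = torch.tensor([0])
hellinger_logspace(x, y).backward()
print(x.grad)  # nonzero; negative at the target column
```

With the logit gap of 12 used here, p_t is about 6e-6, well under the 0.0001 clamp in the question, yet the variant still yields a nonzero gradient. For gaps as large as the 100 in the question, the gradient becomes vanishingly small anyway, because the softmax itself saturates.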