Triplet Loss Returning Zero

I was testing triplet loss and got something very weird. Let's say my anchors, pos_embeddings, and neg_embeddings are:
(tensor([[8.0368e+16, 4.2619e+16]]), tensor([[0.2196, 0.2067]]), tensor([[0.1873, 0.9207]]))


Calling

F.triplet_margin_loss(anchors.long(), pos_embeddings.long(), neg_embeddings.long(), margin=1, p=2, reduction='none')

returns 0.

I know the anchor values are extremely large, but I want to understand how triplet loss behaves here.

In your code snippet you are casting the inputs to LongTensors, which truncates every value toward zero, so pos_embeddings and neg_embeddings both become zeros and are therefore equal:

anchors, pos_embeddings, neg_embeddings = torch.tensor([[8.0368e+16, 4.2619e+16]]), torch.tensor([[0.2196, 0.2067]]), torch.tensor([[0.1873, 0.9207]])
pos_embeddings.long()
> tensor([[0, 0]])
neg_embeddings.long()
> tensor([[0, 0]])
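To see why the two collapse to the same point, here is a minimal pure-Python sketch. It assumes that int() is a fair stand-in for Tensor.long() on these values, since both truncate toward zero; the variable names mirror the thread and nothing here calls PyTorch.

```python
# Plain lists stand in for the 1x2 tensors from the question (no PyTorch needed).
pos_embeddings = [0.2196, 0.2067]
neg_embeddings = [0.1873, 0.9207]

# int() truncates toward zero, like a float -> int64 cast,
# so every value in (-1, 1) becomes 0.
pos_long = [int(x) for x in pos_embeddings]
neg_long = [int(x) for x in neg_embeddings]

print(pos_long, neg_long)  # [0, 0] [0, 0]
print(pos_long == neg_long)  # True: positive and negative are now the same point
```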

Given that, d(a, p) = d(a, n), so the output of the loss function should be the margin, which is set to 1.
However, the intermediate distances are huge, because the anchor is far away from the origin:
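For reference, the per-triplet loss with the p-norm distance is written out below (ignoring the small eps term that PyTorch adds inside the distance), together with the consequence of the positive and negative collapsing to the same point.

```latex
L(a, p, n) = \max\{\, d(a, p) - d(a, n) + \text{margin},\; 0 \,\}, \qquad d(x, y) = \lVert x - y \rVert_p

p = n \;\Rightarrow\; d(a, p) = d(a, n) \;\Rightarrow\; L = \max\{\text{margin},\; 0\} = 1
```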

pos_embeddings = torch.tensor([[0, 0]])
neg_embeddings = pos_embeddings
dist_pos = torch.norm(anchors - pos_embeddings, p=2)
> tensor(1.2923e+10, grad_fn=<NormBackward1>)
dist_neg = torch.norm(anchors - neg_embeddings, p=2)
> tensor(1.2923e+10, grad_fn=<NormBackward1>)

The next steps add the margin to the distance difference and clamp the result at zero:

output = torch.clamp_min(1. + dist_pos - dist_neg, 0)
> tensor(0., grad_fn=<ClampMinBackward0>)

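Here you can see that 1. + dist_pos - dist_neg evaluates to 0: the margin is absorbed when it is added to the large distance, so subtracting dist_neg afterwards cancels everything. The reason is the spacing between consecutive float32 values, which is larger than 1 for values of 2**24 and above, as described in this Wikipedia article.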
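Python floats are float64, so the sketch below emulates float32 arithmetic by rounding through the standard struct module (the helper names to_f32, f32_add and f32_sub are hypothetical). It first shows the spacing effect at 2**24 and then replays 1. + dist_pos - dist_neg with the distance value printed above. It is an emulation under these assumptions, not PyTorch's own kernel.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (a float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def f32_add(a: float, b: float) -> float:
    """Emulate a float32 addition: add in float64, then round to float32."""
    return to_f32(to_f32(a) + to_f32(b))


def f32_sub(a: float, b: float) -> float:
    """Emulate a float32 subtraction: subtract in float64, then round to float32."""
    return to_f32(to_f32(a) - to_f32(b))


# From 2**24 upward, float32 values are spaced at least 2 apart, so adding 1 can be lost.
print(to_f32(float(2**23 + 1)))  # 8388609.0 (still exact)
print(to_f32(float(2**24 + 1)))  # 16777216.0 (rounded back down)

# Replay 1. + dist_pos - dist_neg with the distance printed in the thread (about 1.29e10).
dist_pos = dist_neg = to_f32(1.2923e10)
step = f32_add(1.0, dist_pos)  # float32 spacing near 1.29e10 is 1024, so the 1 vanishes
print(step == dist_pos)  # True
loss = max(f32_sub(step, dist_neg), 0.0)
print(loss)  # 0.0
```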
A potential fix would be to use:

output = torch.clamp_min(1. + (dist_pos - dist_neg), 0)

or to compute the loss in float64.
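The two suggested fixes can be compared with the same kind of float32 emulation (hypothetical helpers again, built on the standard struct module), using the equal distances from the thread:

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (a float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def f32_add(a: float, b: float) -> float:
    """Emulate a float32 addition: add in float64, then round to float32."""
    return to_f32(to_f32(a) + to_f32(b))


def f32_sub(a: float, b: float) -> float:
    """Emulate a float32 subtraction: subtract in float64, then round to float32."""
    return to_f32(to_f32(a) - to_f32(b))


margin = 1.0
dist_pos = dist_neg = 1.2923e10  # the (equal) distances printed in the thread

# float32, original order: (margin + dist_pos) - dist_neg
naive = max(f32_sub(f32_add(margin, dist_pos), dist_neg), 0.0)
# float32, distances subtracted first: margin + (dist_pos - dist_neg)
reordered = max(f32_add(margin, f32_sub(dist_pos, dist_neg)), 0.0)
# float64 (Python's native float), original order
wide = max(margin + dist_pos - dist_neg, 0.0)

print(naive, reordered, wide)  # 0.0 1.0 1.0
```

Reordering rescues this particular case because the two distances are exactly equal. When they differ, each float32 distance has already been rounded to a spacing of roughly 1024 near 1.29e10 before the subtraction, so small true differences are lost either way; float64, or embeddings on a sensible scale, is the more robust option.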

CC @tom what do you think about subtracting the distances first? Could this yield any other (unwanted) issues?


It might work, but TBH neither the cast to long nor the large numbers make sense to me.
The original intention of the triplet loss was to have feature vectors from three examples, where a and p are of the same class and n is of a different one. I don't think having the anchors live on a completely different scale than p and n makes much sense for that.
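To illustrate the intended setting, here is a minimal pure-Python sketch of the per-triplet formula (function names are hypothetical and the eps term is omitted) applied to made-up 2-D embeddings that all live on the same scale, so a margin of 1 is meaningful relative to the distances.

```python
import math


def euclidean(x, y):
    """p=2 distance between two equal-length vectors."""
    return math.sqrt(sum((xi - yi) ** 2 for xi, yi in zip(x, y)))


def triplet_margin_loss(a, p, n, margin=1.0):
    """Per-triplet loss max(d(a, p) - d(a, n) + margin, 0), without an eps term."""
    return max(euclidean(a, p) - euclidean(a, n) + margin, 0.0)


anchor = [0.20, 0.20]
positive = [0.25, 0.20]       # same class: close to the anchor
far_negative = [0.20, 1.50]   # other class, already more than `margin` further away
near_negative = [0.30, 0.20]  # other class, but still close to the anchor

easy = triplet_margin_loss(anchor, positive, far_negative)
hard = triplet_margin_loss(anchor, positive, near_negative)
print(easy)  # 0.0: the negative is already far enough away
print(round(hard, 2))  # 0.95: the negative still has to be pushed away
```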
