Distance function in TripletMarginLoss

Hi.
Is the distance function implemented in TripletMarginLoss with p=2 the same as the one implemented by torch.dist(p=2)?


nn.TripletMarginLoss uses torch.pairwise_distance internally, as seen here. While pairwise_distance should yield the same result as torch.dist for 1-dimensional tensors, it will not yield the same outputs for multi-dimensional tensors (the actual TripletMarginLoss would, of course, also apply the specified margin):

import torch

# single dim
x1 = torch.randn(10)
x2 = torch.randn(10)

# scalar outputs
out1 = torch.pairwise_distance(x1, x2, p=2.0)
out2 = torch.dist(x1, x2, p=2.0)

print((out1 - out2).abs().max())
# tensor(4.7684e-07)

# multi-dim
x1 = torch.randn(10, 2)
x2 = torch.randn(10, 2)

out1 = torch.pairwise_distance(x1, x2, p=2.0)
out2 = torch.dist(x1, x2, p=2.0)

out1_manual = torch.sqrt(torch.sum((x1 - x2)**2, dim=1))
print((out1 - out1_manual).abs().max())
# tensor(1.4305e-06)

out2_manual = torch.sqrt(torch.sum((x1 - x2)**2))
print((out2 - out2_manual).abs().max())
# tensor(0.)
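To make the difference explicit, the outputs above also differ in shape: pairwise_distance returns one distance per row, while torch.dist reduces over all elements to a single scalar. A small check, continuing the multi-dim example above:

print(out1.shape)
# torch.Size([10])
print(out2.shape)
# torch.Size([])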

Thanks @ptrblck for providing the explanation along with an example.
As a result, is it true to say that if we use nn.TripletMarginLoss during training, we shouldn't use torch.dist at test time? During training the tensors are multi-dimensional (i.e., Anchor.shape == Positive.shape == Negative.shape == [batch, N], where batch is the batch size and N is the length of the output vector).
If the above is true, what kind of distance metric should I use at test time, when I have just one Anchor, one Positive, and one Negative instance?
Thanks.

@ptrblck I have one more question, if possible.
I created an object from nn.TripletMarginLoss, and in each iteration I push three embeddings (anchor, positive, and negative) into this function. Then I use the output as the loss for my model. Is that correct, or is nn.TripletMarginLoss just used to obtain the hard triplets?
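For context, a minimal sketch of the training setup being described, assuming a placeholder embedding model and optimizer (the nn.Linear model, the sizes, and the SGD settings are made up for illustration):

import torch
import torch.nn as nn

# placeholder embedding network (made up for illustration)
model = nn.Linear(32, 16)
criterion = nn.TripletMarginLoss(margin=1.0, p=2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# one training iteration: embed a batch of triplets and use the
# criterion output directly as the loss
anchor = model(torch.randn(8, 32))
positive = model(torch.randn(8, 32))
negative = model(torch.randn(8, 32))

loss = criterion(anchor, positive, negative)
optimizer.zero_grad()
loss.backward()
optimizer.step()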

@ptrblck I tested the code below. It seems that all implementations follow the same formula here, unlike in your example, where torch.dist differed from torch.pairwise_distance.

import torch

a = torch.rand(5, 4)
p = torch.rand(5, 4)
n = torch.rand(5, 4)

triplet_loss = torch.nn.TripletMarginLoss(margin=1, p=2, reduction='none')

# built-in loss: one value per triplet
triplet = triplet_loss(a, p, n)
print('triplet ', triplet)
print('###')

# same formula via torch.pairwise_distance
ap = torch.pairwise_distance(a, p, p=2)
an = torch.pairwise_distance(a, n, p=2)
print(torch.nn.ReLU()(ap - an + 1))
print('@@@@@@@@@')

# same formula via torch.dist applied row by row
for i in range(a.shape[0]):
    ap = torch.dist(a[i, :], p[i, :], p=2)
    an = torch.dist(a[i, :], n[i, :], p=2)
    print(torch.nn.ReLU()(ap - an + 1))
print('!!!!')

# manual computation of the same formula
ap = (a - p).pow(2).sum(1).sqrt()
an = (a - n).pow(2).sum(1).sqrt()
print(torch.nn.ReLU()(ap - an + 1))

Yes, that’s correct.

Note the shapes of my inputs as well as the outputs and compare them to the manual computations in out1_manual and out2_manual.
If you avoid summing over all dimensions (out2_manual) and instead apply torch.dist in a loop, the formulas should be equal.
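To close the loop on the earlier test-time question: since torch.dist and torch.pairwise_distance agree for 1-dimensional tensors, either can be used when comparing a single anchor/positive/negative triplet. A minimal sketch (the embedding length 16 is arbitrary):

import torch

# single test-time embeddings (length is arbitrary here)
anchor = torch.randn(16)
positive = torch.randn(16)
negative = torch.randn(16)

# for 1-dim tensors, torch.dist matches torch.pairwise_distance
d_ap = torch.dist(anchor, positive, p=2)
d_an = torch.dist(anchor, negative, p=2)
print(d_ap < d_an)  # True if the positive is closer to the anchor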