Please help me check whether the tripletloss is correct.
This is the code:
label=np.array([0.,1.,2.,3.,0.,1.])
label=torch.from_numpy(label)
label=label.reshape((6,1))
N=6
is_pos = label.expand(N, N).eq(label.expand(N, N).t())
is_neg = label.expand(N, N).ne(label.expand(N, N).t())
dist=comput_dist(label,label)
list_ap=[]
list_an = []
for i in range(N):
# Find the maximum distance between the ith label and the positive sample
list_ap.append(dist[i][is_pos[i]].max().unsqueeze(0))
#Find the minimum distance between the ith label and its negative sample
list_an.append(dist[i][is_neg[i]].min().unsqueeze(0))
dist_ap = torch.cat(list_ap)
dist_an = torch.cat(list_an)
y = torch.ones_like(dist_an)
loss =torch.nn.MarginRankingLoss()(dist_an,dist_ap,y)
print(loss)