Hello, I am trying to implement online hard/semi-hard triplet mining in PyTorch. Below is my code:
def mineHard(model, anchor, positive, negative, semiHard=False, margin=0.3):
    """Select the hard (or semi-hard) triplets of a batch for online mining.

    Runs ``model`` on the batch without tracking gradients, computes the
    anchor-positive and anchor-negative Euclidean distances, and keeps only
    the triplets that satisfy the mining criterion:

    * hard (default):  d(a, p) > d(a, n)
    * semi-hard:       d(a, p) < d(a, n) < d(a, p) + margin

    Args:
        model: module whose forward takes ``(anchor, positive, negative)``
            and returns three embedding batches.
        anchor, positive, negative: input batches; rows are selected along
            dim 0.
        semiHard: mine semi-hard triplets instead of hard ones.
        margin: triplet margin used by the semi-hard criterion
            (defaults to the previously hard-coded 0.3).

    Returns:
        ``(anchor, positive, negative, True)`` containing only the selected
        rows (detached, on the model's device), or
        ``(None, None, None, False)`` when no triplet matches.
    """
    # Move inputs to wherever the model lives instead of assuming CUDA.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else anchor.device
    anchor, positive, negative = (
        t.to(device) for t in (anchor, positive, negative)
    )

    # Mining happens in eval mode, but the caller trains right afterwards:
    # restore the previous mode, otherwise dropout stays disabled and
    # BatchNorm statistics stay frozen during training.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # mining only needs distances, not a graph
            out_a, out_p, out_n = model(anchor, positive, negative)
            # view(-1): older PyTorch returns (N, 1), newer returns (N,).
            d_pos = F.pairwise_distance(out_a, out_p).view(-1)
            d_neg = F.pairwise_distance(out_a, out_n).view(-1)
    finally:
        model.train(was_training)

    if semiHard:
        mask = (d_pos < d_neg) & (d_neg < d_pos + margin)
    else:
        mask = d_pos > d_neg

    if not mask.any():
        return None, None, None, False

    # Boolean indexing keeps row 0; the previous ``arange * mask`` + nonzero
    # trick turned a selected row 0 into 0 and silently dropped it.
    return (
        anchor[mask].detach(),
        positive[mask].detach(),
        negative[mask].detach(),
        True,
    )
The problem is that whenever I train my network on these mined triplets, performance starts to deteriorate during the first epoch, and after the second epoch none of the triplets in my validation set is classified correctly.
Any help would be appreciated.