Online Triplet Hard Mining

Hello, I am trying to implement online hard/semi-hard triplet mining in pytorch. Below is my code: -

def mineHard(model, anchor, positive, negative, semiHard=False):
    cnn = model
    cnn.eval()
    margin = 0.3

    anchor, positive, negative = Variable(anchor).cuda(), Variable(positive).cuda(), Variable(negative).cuda()
    output1, output2, output3 = cnn(anchor, positive, negative)
    
    d_pos = F.pairwise_distance(output1, output2)
    d_neg = F.pairwise_distance(output1, output3)
    if semiHard:
        pred1 = (d_pos - d_neg).cpu().data
        pred2 = (d_pos + margin - d_neg).cpu().data
        indices = numpy.logical_and((pred1 < 0), (pred2 > 0))
    else:
        pred = (d_pos - d_neg).cpu().data
        indices = pred > 0
    
    if indices.sum() == 0:
        return None, None, None, False

    x = torch.arange(0, d_pos.size()[0]).view(d_pos.size()[0], 1)
    indices = x.type(torch.cuda.FloatTensor) * indices.type(torch.cuda.FloatTensor)
    
    nonzero_indices = torch.nonzero(indices)
    indices = indices[nonzero_indices[:, 0], :].view(nonzero_indices.size()[0]).type(torch.cuda.LongTensor)
    
    anchor = torch.index_select(anchor.data, 0, indices)
    positive = torch.index_select(positive.data, 0, indices)
    negative = torch.index_select(negative.data, 0, indices)

    return anchor, positive, negative, True

Now, the problem is that whenever I try to train my network on these extracted examples, in the first epoch the performance starts to deteriorate, and then after the second epoch none of my triplets in validation set is correctly classified.
Any help would be appreciated.

Shouldn’t you do a L2 distance or similar distance based measure rather than a per element subtraction for selecting the indices?