I want an efficient (vectorized) implementation instead of my current for-loop implementation.
I have an MxN matrix named Sim containing the similarity scores of M anchors with N documents.
For each anchor I have multiple negative (non-matching) documents as well as multiple positive (matching) documents, and the number of positives/negatives is not the same for every anchor.
A triplet margin loss is defined on a triplet (anchor, positive, negative) and is computed as follows:
H(a,p,n) = max(0, margin + Sim(a,n) - Sim(a,p))
The anchor corresponds to a row index, while the positive and negative are column indices.
For each anchor, my loss function is defined as the mean of the triplet losses over the fraction f of least similar positives and the fraction f of most similar negatives, with f being a hyper-parameter (the same for all anchors).
Say for anchor i we have 10 positives and 100 negatives; with f=0.5 we want to select the 5 minimum Sim(i,p) and the 50 maximum Sim(i,n), and compute the mean of the triplet losses over all positive/negative pairs (the mean of a 5x50 matrix). Then for anchor j with 20 positives and 90 negatives, that gives a 10x45 matrix of triplet losses to average.
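Written compactly in the same notation, the per-anchor loss I am describing is:

L(a) = mean over p in P_f(a), n in N_f(a) of H(a, p, n)

where P_f(a) is the fraction f of least similar positives of anchor a and N_f(a) is the fraction f of most similar negatives of anchor a.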
I have an implementation which computes the loss of each anchor in a for loop, but as you can imagine this is slow (with a small similarity-scoring network, about 50% of the training time is spent computing the loss).
If I constrain my problem to only have 1 positive document per anchor, I can make an efficient vectorized implementation (negligible time cost). I suppose doing an efficient implementation with a different (but still constant) number of positives would not be too hard.
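For reference, this is roughly what my single-positive version looks like (a minimal sketch, assuming exactly one entry per row of targets is >= 0.5, so every anchor has the same number N-1 of negatives and topk can be applied row-wise):

import math
import torch

def single_positive_loss(sim, targets, margin, hardest_fraction):
    # Assumes exactly one positive per row of targets.
    pos_mask = targets >= 0.5                               # (M, N) boolean, one True per row
    pos_scores = sim[pos_mask].unsqueeze(1)                 # (M, 1) score of each anchor's positive
    # Hide the positive column behind -inf before ranking the negatives.
    neg_scores = sim.masked_fill(pos_mask, float('-inf'))
    num_negatives = sim.size(1) - 1                         # constant across anchors
    k = max(1, math.ceil(hardest_fraction * num_negatives))
    hardest_negatives = neg_scores.topk(k, dim=1).values    # (M, k) most similar negatives
    losses = (margin + hardest_negatives - pos_scores).clamp(min=0)
    # k is the same for every anchor, so the global mean equals
    # the mean of the per-anchor means.
    return losses.mean()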
How can I efficiently implement the loss with a varying number of positive documents for each anchor?
Here is the code for the loop version:
import math
import torch


def hardest_triplet_margin_loss(preds, targets, margin, hardest_fraction):
    """
    preds is a vector of N similarity scores for one anchor.
    targets is a vector of N floats giving the 'positiveness'
    of the corresponding element in preds.
    """
    targets = targets >= 0.5
    positives = preds[targets]
    negatives = preds[~targets]
    # Fraction f of least similar positives, as a column vector.
    hardest_positives = top_fraction(positives, hardest_fraction, False) \
        .unsqueeze(1)
    # Fraction f of most similar negatives, one row per hardest positive.
    hardest_negatives = top_fraction(negatives, hardest_fraction, True) \
        .unsqueeze(0).expand(hardest_positives.numel(), -1)
    losses = (margin + hardest_negatives - hardest_positives).clamp(min=0)
    loss = losses.mean()
    return loss


def top_fraction(x, fraction, largest=True):
    # Select the top (or bottom) fraction of the elements, keeping at least one.
    k = max(1, math.ceil(fraction * x.numel()))
    out = x.topk(k, largest=largest)[0]
    return out
num_anchors = 100
num_documents = 200
margin = 0.2
hardest_fraction = 0.5

# Matrix of similarity scores of anchors with documents (predicted by a NN).
sim = torch.rand(num_anchors, num_documents)
# Matrix of 'positiveness' of documents for each anchor
# (>= 0.5 means it is a positive); given as input in the data.
targets = torch.rand(num_anchors, num_documents)

anchor_losses = torch.zeros(num_anchors)
for i in range(num_anchors):
    anchor_losses[i] = hardest_triplet_margin_loss(
        sim[i], targets[i], margin, hardest_fraction)
loss = anchor_losses.mean()
If it’s not possible to fully vectorize the computation, I suppose doing it for blocks of anchors having the same number of positives would be the next best thing, and I think I can handle that on my own.
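In case it helps to see what I mean by that, this is roughly the block-wise version I have in mind (a rough, untested sketch with a hypothetical grouped_hardest_triplet_loss name; it only loops over the distinct positive counts instead of over every anchor):

import math
import torch

def grouped_hardest_triplet_loss(sim, targets, margin, hardest_fraction):
    pos_mask = targets >= 0.5
    pos_counts = pos_mask.sum(dim=1)                         # number of positives per anchor
    anchor_losses = torch.zeros(sim.size(0), device=sim.device)
    for p in pos_counts.unique():                            # one block per distinct positive count
        rows = (pos_counts == p).nonzero(as_tuple=True)[0]
        block_sim = sim[rows]                                # (B, N)
        block_mask = pos_mask[rows]                          # (B, N), exactly p positives per row
        n = sim.size(1) - int(p)                             # negatives per row in this block
        kp = max(1, math.ceil(hardest_fraction * int(p)))
        kn = max(1, math.ceil(hardest_fraction * n))
        # Least similar positives: hide negatives behind +inf, take the kp smallest.
        hardest_pos = block_sim.masked_fill(~block_mask, float('inf')) \
            .topk(kp, dim=1, largest=False).values           # (B, kp)
        # Most similar negatives: hide positives behind -inf, take the kn largest.
        hardest_neg = block_sim.masked_fill(block_mask, float('-inf')) \
            .topk(kn, dim=1).values                          # (B, kn)
        losses = (margin + hardest_neg.unsqueeze(1)
                  - hardest_pos.unsqueeze(2)).clamp(min=0)   # (B, kp, kn)
        anchor_losses[rows] = losses.mean(dim=(1, 2))
    return anchor_losses.mean()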
Thanks for reading!