How can I vectorize this for-loop loss function?

I want an efficient (vectorized) implementation instead of my current for-loop implementation.

I have an MxN matrix named Sim containing the similarity scores of M anchors with N documents.
Each anchor has multiple negative (non-matching) documents as well as multiple positive (matching) documents, and the number of negatives/positives is not the same for every anchor.

A triplet margin loss is defined on a triplet (anchor, positive, negative) and is computed as follows:

H(a,p,n) = max(0, margin + Sim(a,n) - Sim(a,p))

The anchor corresponds to a row index, while the positive and negative are column indices.
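In code, the loss of one triplet would look like this (sim being the MxN score tensor; single_triplet_loss is just for illustration):

def single_triplet_loss(sim, a, p, n, margin):
    # H(a, p, n) = max(0, margin + Sim(a, n) - Sim(a, p))
    return (margin + sim[a, n] - sim[a, p]).clamp(min=0)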

For each anchor, my loss function is defined as the mean of the triplet losses over the f% least similar positives and the f% most similar negatives, with f being a hyper-parameter (the same for all queries).

Say anchor i has 10 positives and 100 negatives: with f=0.5, we select the 5 minimum Sim(i,p) and the 50 maximum Sim(i,n), and compute the mean of the triplet losses over all positive/negative pairs (the mean of a 5x50 matrix). Then for anchor j with 20 positives and 90 negatives, we get a 10x45 matrix of triplet losses to average.
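The selection sizes are rounded up, i.e. k = ceil(f * count) with at least one element kept; this matches the top_fraction helper in the code below:

import math

f = 0.5
print(math.ceil(f * 10), math.ceil(f * 100))   # 5 50  -> a 5x50 matrix for anchor i
print(math.ceil(f * 20), math.ceil(f * 90))    # 10 45 -> a 10x45 matrix for anchor j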

I have an implementation which computes the loss of each anchor in a for loop, but as you can imagine this is slow (with a small network producing the sim scores, 50% of the training time is spent computing the loss).
If I constrain my problem to only have 1 positive document per anchor, I can make an efficient vectorized implementation (negligible time cost). I suppose doing an efficient implementation with a different (but still constant) number of positives would not be too hard.
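For reference, a minimal sketch of that single-positive case could look like this (not my exact code; it assumes exactly one positive per row of targets and applies f only to the negatives):

import math
import torch

def single_positive_hardest_fraction_loss(sim, targets, margin, hardest_fraction):
    # Sketch only: assumes every row of `targets` marks exactly one positive.
    pos_mask = targets >= 0.5                                    # [M, N], one True per row
    pos_scores = sim[pos_mask].unsqueeze(1)                      # [M, 1] positive score per anchor
    neg_scores = sim.masked_fill(pos_mask, float('-inf'))        # exclude the positive column
    k = max(1, math.ceil(hardest_fraction * (sim.size(1) - 1)))  # same k for every anchor
    hardest_negatives = neg_scores.topk(k, dim=1)[0]             # [M, k] most similar negatives
    return (margin + hardest_negatives - pos_scores).clamp(min=0).mean()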

How can I efficiently implement the loss with a varying number of positive documents for each anchor?

Here is the code for the loop version:

import math
import torch

def hardest_triplet_margin_loss(preds, targets, margin, hardest_fraction):
    """
        preds is a vector of N scores.
        targets is a vector of N floats giving the
        'positiveness' of the corresponding element in preds.
    """
    targets = targets >= 0.5
    positives = preds[targets]
    negatives = preds[~targets]

    hardest_positives = top_fraction(positives, hardest_fraction, False)\
            .unsqueeze(1)
    hardest_negatives = top_fraction(negatives, hardest_fraction, True)\
            .unsqueeze(0).expand(hardest_positives.numel(), -1) 

    losses = (margin + hardest_negatives - hardest_positives).clamp(min=0)
    loss = losses.mean()
    return loss

def top_fraction(x, fraction, largest=True):
    k = max(1, math.ceil(fraction * x.numel()))
    out = x.topk(k, largest=largest)[0]
    return out

num_anchors = 100
num_documents = 200

margin = 0.2
hardest_fraction = 0.5

# Matrix of similarity scores of anchors with documents (predicted by a NN)
sim = torch.rand(num_anchors, num_documents)

# Matrix of 'positiveness' of documents for each query. 
# (>=0.5 means it's a positive) (given as input in the data)
targets = torch.rand(num_anchors, num_documents)

anchor_losses = torch.zeros(num_anchors)
for i in range(num_anchors):
    anchor_losses[i] = hardest_triplet_margin_loss(
        sim[i], targets[i], margin, hardest_fraction) 

loss = anchor_losses.mean()

If it’s not possible to fully vectorize the computation, I suppose doing it for blocks of anchors having the same number of positives would be the next best thing, and I think I can handle that on my own.

Thanks for reading!

Hi Jhuteau!

Easy peasy:

import math

import torch

print (torch.__version__)

torch.manual_seed (20212021)


# code with tensor operations

def hardest_triplet_margin_lossB (preds, targets, margin, hardest_fraction):
    # preds:             the similarity scores, of shape [num_anchors, num_documents]
    # targets:           the "positiveness" of the scores, of the same shape
    # margin:            the margin-loss margin
    # hardest_fraction:  fraction of hardest "positive" and "negative" scores to use
    #                    (least-similar positives, most-similar negatives)
    
    n_doc = preds.shape[1]
    
    pos_mask = (targets >= 0.5).int()
    neg_mask = 1 - pos_mask
    max_preds = preds.max (dim = 1)[0]
    min_preds = preds.min (dim = 1)[0]
    
    # push the negatives above the row maximum so that, after a descending sort,
    # the positives end up in the trailing columns (hardest, i.e. least similar, last)
    pos_sort = pos_mask * preds  +  neg_mask * (1 + max_preds).unsqueeze (1)
    pos_sort = pos_sort.sort (dim = 1, descending = True)[0]
    # per-row number of hardest positives to keep (at least one)
    pos_k = (hardest_fraction * pos_mask.sum (dim = 1)).ceil().maximum (torch.ones (1))
    # mask selecting the last pos_k columns of each row
    pos_hard_mask = ((n_doc - pos_k).unsqueeze (1) <= torch.arange (n_doc)).int()
    
    # push the positives below the row minimum so that, after an ascending sort,
    # the negatives end up in the trailing columns (hardest, i.e. most similar, last)
    neg_sort = neg_mask * preds  +  pos_mask * (min_preds - 1).unsqueeze (1)
    neg_sort = neg_sort.sort (dim = 1)[0]
    neg_k = (hardest_fraction * neg_mask.sum (dim = 1)).ceil().maximum (torch.ones (1))
    neg_hard_mask = ((n_doc - neg_k).unsqueeze (1) <= torch.arange (n_doc)).int()
    
    # all pairwise (positive, negative) triplet margins, restricted to the hard pairs
    margin_mask = neg_hard_mask.unsqueeze (1) * pos_hard_mask.unsqueeze (2)
    margins = ((neg_sort.unsqueeze (1) - pos_sort.unsqueeze (2)) + margin).clamp (min = 0)
    
    # per-anchor mean of the triplet losses over the selected hard pairs
    losses = (margin_mask * margins).sum (dim = (1, 2)) / margin_mask.sum (dim = (1, 2))
    loss = losses.mean()
    return loss


# code with loop

def hardest_triplet_margin_loss(preds, targets, margin, hardest_fraction):
    """
        preds is a vector of N scores.
        targets is a vector of N floats giving the
        'positiveness' of the corresponding element in preds.
    """
    targets = targets >= 0.5
    positives = preds[targets]
    negatives = preds[~targets]

    hardest_positives = top_fraction(positives, hardest_fraction, False)\
            .unsqueeze(1)
    hardest_negatives = top_fraction(negatives, hardest_fraction, True)\
            .unsqueeze(0).expand(hardest_positives.numel(), -1) 

    losses = (margin + hardest_negatives - hardest_positives).clamp(min=0)
    loss = losses.mean()
    return loss

def top_fraction(x, fraction, largest=True):
    k = max(1, math.ceil(fraction * x.numel()))
    out = x.topk(k, largest=largest)[0]
    return out

num_anchors = 100
num_documents = 200

margin = 0.2
hardest_fraction = 0.5

# Matrix of similarity scores of anchors with documents (predicted by a NN)
sim = torch.rand(num_anchors, num_documents)

# Matrix of 'positiveness' of documents for each query. 
# (>=0.5 means it's a positive) (given as input in the data)
targets = torch.rand(num_anchors, num_documents)

anchor_losses = torch.zeros(num_anchors)
for i in range(num_anchors):
    anchor_losses[i] = hardest_triplet_margin_loss(
        sim[i], targets[i], margin, hardest_fraction) 

loss = anchor_losses.mean()

###

print ('loss =', loss)

lossB = hardest_triplet_margin_lossB (sim, targets, margin, hardest_fraction)
print ('lossB =', lossB)

print ('torch.allclose (lossB, loss) =', torch.allclose (lossB, loss))
print ('(lossB == loss) =', (lossB == loss))
Here is the output:

1.7.1
loss = tensor(0.6941)
lossB = tensor(0.6941)
torch.allclose (lossB, loss) = True
(lossB == loss) = tensor(True)

Best.

K. Frank


I benchmarked different implementations on CPU, namely KFrank’s solution, chunked_hardest_fraction_triplet_margin_loss, and hardest_triplet_margin_loss.
chunked splits the preds into groups of anchors having the same number of positives, while hardest only works if you want the absolute hardest triplet.

# Multiple positives, hardest fraction, chunk vectorized O(n^2) memory
def chunked_hardest_fraction_triplet_margin_loss(
        preds, targets, margin, hardest_fraction):
    targets = (targets >= 0.5)

    num_positives = targets.sum(dim=1)

    unique_num_positives = num_positives.unique()

    loss = torch.tensor(0.0)
    for single_num_positives in unique_num_positives:
        group_indices = (num_positives == single_num_positives)\
            .nonzero(as_tuple=True)

        group_loss = constant_hardest_fraction_triplet_margin_loss(
                preds[group_indices], targets[group_indices], 
                margin, hardest_fraction)

        loss += group_loss * group_indices[0].size(0)
    loss /= preds.size(0)

    return loss

# Multiple (same # of) positives, hardest fraction, vectorized O(n^2) memory
def constant_hardest_fraction_triplet_margin_loss(
        preds, targets, margin, hardest_fraction):

    # number of anchors in this group (each row has the same number of positives)
    num_anchors = preds.size(0)
    targets = (targets >= 0.5)

    positive_indices = targets.nonzero(as_tuple=True)
    negative_indices = (~targets).nonzero(as_tuple=True)

    positives = preds[positive_indices].view(num_anchors, -1)
    negatives = preds[negative_indices].view(num_anchors, -1)

    pos_k = max(1, math.ceil(hardest_fraction * positives.size(1)))
    neg_k = max(1, math.ceil(hardest_fraction * negatives.size(1)))

    hardest_positives = positives.topk(pos_k, dim=1, largest=False)[0]
    hardest_negatives = negatives.topk(neg_k, dim=1, largest=True)[0]

    hardest_positives = hardest_positives.unsqueeze(1)
    hardest_negatives = hardest_negatives.unsqueeze(2)

    triplet_losses = (hardest_negatives + margin - hardest_positives).clamp(min=0)
    anchor_losses = triplet_losses.mean(dim=1)

    loss = anchor_losses.mean() / margin
    return loss

# Multiple positives, hardest triplet only, vectorized with O(n^2) memory
def hardest_triplet_margin_loss(
        preds, targets, margin):
    max_pred = preds.max(dim=1)[0]
    min_pred = preds.min(dim=1)[0]

    pos_mask = (targets >= 0.5).int()
    neg_mask = 1 - pos_mask

    positives = preds * pos_mask + neg_mask * (max_pred + 1).unsqueeze(1)
    hardest_pos = positives.min(dim=1)[0]

    negatives = preds * neg_mask + pos_mask * (min_pred - 1).unsqueeze(1)
    hardest_neg = negatives.max(dim=1)[0]

    anchor_losses = (hardest_neg + margin - hardest_pos).clamp(min=0)
    loss = anchor_losses.mean() / margin
    return loss
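
The chunked and hardest variants take the full sim / targets matrices from the first post directly (note that they divide by margin, so the values are scaled compared to the loop version), e.g.:

loss_chunked = chunked_hardest_fraction_triplet_margin_loss(
    sim, targets, margin, hardest_fraction)
loss_hardest = hardest_triplet_margin_loss(sim, targets, margin)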

naive is a Python for-loop implementation, anchor is the implementation from the original post of this thread (mine), and vectorized is KFrank’s implementation. I normalized the timings by the fastest one. MP_HF stands for multiple positives with a fraction of the hardest, and MP_H for multiple positives with only the single hardest triplet.
The tests were run on a 4-core CPU.
I expect KFrank’s implementation to fare better on a GPU than the chunked implementation (but it requires O(num_anchors * num_documents^2) memory for the pairwise margins). It also depends on how well the anchors can be grouped (in the worst case chunked == anchor, and anchor is slower than vectorized).
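For anyone who wants to reproduce the comparison, a simple timing harness along these lines is enough (bench is just a hypothetical helper, not my exact benchmark script):

import time

def bench(fn, *args, repeats=100):
    # Average wall-clock time of fn(*args) over `repeats` calls.
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats

t_chunked = bench(chunked_hardest_fraction_triplet_margin_loss,
                  sim, targets, margin, hardest_fraction)
t_vectorized = bench(hardest_triplet_margin_lossB,
                     sim, targets, margin, hardest_fraction)
fastest = min(t_chunked, t_vectorized)
print('chunked:', t_chunked / fastest)
print('vectorized:', t_vectorized / fastest)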

Starting MP_HF.
naive: 430.700
anchor: 9.194
chunked: 1.000
vectorized: 2.516

Starting MP_H.
naive: 1198.327
anchor: 48.262
chunked: 3.912
vectorized: 11.033
hardest: 1.000