Odd behavior in a custom triplet loss function

Hello,

I have a model that compares embeddings of its inputs to embeddings of a constant set of templates. I’m using a custom triplet margin loss because my model computes the L2 distances internally. I’m getting an intermittent error (it doesn’t happen every batch) where one of my dimensions is off by 1.

import torch

def calc_batch_hard_triplet_loss(distances, gt_indices):
    """
    Given the L2 distances between the embedded anchor in each batch element and each of
    the N templates, plus the index of the single matching template, calculates the
    batch hard triplet loss.

    Args:
        distances: float tensor of shape (batch, N) containing the L2 embedding distances
        gt_indices: int64 tensor of shape (batch,) containing the index of the matching
            template for each anchor
    Returns:
        The mean of the hardest triplet loss over the batch
    """
    batch_size = distances.shape[0]
    n_templates = distances.shape[1]
    # One row of template indices [0 .. N-1] per batch element
    template_indices = torch.arange(n_templates, dtype=torch.int64).repeat(batch_size, 1).cuda()
    # Drop each anchor's matching template, leaving N - 1 non-match indices per row
    non_matches = template_indices[torch.where(template_indices != gt_indices.view(batch_size, 1))]
    non_matches = non_matches.view(batch_size, n_templates - 1)  # error here after ~40 batches
    # Hardest negative: the closest non-matching template for each anchor
    min_non_match_distances = torch.gather(distances, dim=1, index=non_matches).min(dim=1).values
    # Squeeze to shape (batch,) so the subtraction below stays per-anchor
    match_distances = torch.gather(distances, dim=1, index=gt_indices.view(batch_size, 1)).squeeze(1)
    # Hinge on a margin of 1.0 and average over the batch
    diff = match_distances - min_non_match_distances + 1.0
    loss = torch.where(diff > 0, diff, torch.zeros_like(diff)).mean()
    return loss
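
For reference, this is roughly how it gets called. The shapes below match my run (batch of 18, 486 templates), but the tensors here are random stand-ins rather than my real model outputs and labels:

# Stand-in inputs with the shapes from my run; in training, distances
# comes from the model and gt_indices from the ground-truth labels.
distances = torch.rand(18, 486).cuda()
gt_indices = torch.randint(0, 486, (18,), dtype=torch.int64).cuda()
loss = calc_batch_hard_triplet_loss(distances, gt_indices)
print(loss.item())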

This loss seems to work for most batches, then crashes with

 File "/work/code/train.py", line 91, in calc_batch_hard_triplet_loss
    non_matches = non_matches.view(batch_size, n_templates - 1)
RuntimeError: shape '[18, 485]' is invalid for input of size 8731

This makes no sense to me. How does the template index selection step end up off by 1? A bug in torch.arange?

(I’m running torch under the nvcr.io/nvidia/pytorch:22.08-py3 docker container)

Thanks for any help!

It looked like an occasional error in torch.arange(). I seem to have worked around it by calling arange once at the start of training and copying the tensor each time, rather than rebuilding it.
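
Roughly what that workaround looks like (a sketch; base_template_indices is just the name I’m using here for the cached tensor):

# Built once before the training loop instead of on every call:
base_template_indices = torch.arange(n_templates, dtype=torch.int64).cuda()

# Inside the loss, copy/expand the cached tensor rather than rebuilding it
# (repeat returns a new tensor, so the cached one is never modified):
template_indices = base_template_indices.repeat(batch_size, 1)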

The fix wasn’t permanent. :frowning: It just took longer to happen. Ideas are still very welcome.

Hey, looks like that function doesn’t have a problem. I just have some bad data entries. :roll_eyes:
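
In case anyone else hits this: the off-by-one makes sense in hindsight. If a gt_index in the batch falls outside [0, N), which is presumably what my bad entries were, the != comparison keeps all N templates as "non-matches" for that row, so the flattened non_matches ends up with one extra element per bad row (18 * 485 + 1 = 8731 in my crash) and the view() fails. A guard like this at the top of the loss would have caught it much earlier (a sketch, not something in my actual code):

def check_gt_indices(gt_indices, n_templates):
    # Hypothetical guard: every ground-truth index must fall in [0, n_templates),
    # otherwise that row keeps all n_templates "non-matches" and the later
    # view(batch_size, n_templates - 1) fails with an off-by-one size.
    bad = (gt_indices < 0) | (gt_indices >= n_templates)
    if bad.any():
        raise ValueError(f"out-of-range gt_indices: {gt_indices[bad].tolist()}")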