I have a model that is comparing embeddings of its inputs to embeddings of a constant set of templates. I’m using a custom triplet margin loss as my model calculates the L2 distance internally. I’m getting an inconsistent error (it doesn’t happen every batch) where one of my dimensions gets off by 1.
```python
import torch

def calc_batch_hard_triplet_loss(distances, gt_indices):
    """
    Given L2 distances between the embedded anchor in each batch and each of the
    N templates, and the indices indicating the single template match, calculates
    the batch hard triplet loss.

    Args:
        distances: float tensor, shape (batch, N), containing the L2 embedding distances
        gt_indices: int tensor, shape (batch,), containing the index of the matching
            template for each anchor

    Returns:
        The mean of the hardest triplet loss for each anchor
    """
    batch_size = distances.shape[0]
    n_templates = distances.shape[1]
    template_indices = torch.arange(n_templates, dtype=torch.int64).repeat(batch_size, 1).cuda()
    non_matches = template_indices[torch.where(template_indices != gt_indices.view(batch_size, 1))]
    non_matches = non_matches.view(batch_size, n_templates - 1)  # error here after ~40 batches
    min_non_match_distances = torch.gather(distances, dim=1, index=non_matches).min(dim=1).values
    match_distances = torch.gather(distances, dim=1, index=gt_indices.view(batch_size, 1))
    diff = match_distances - min_non_match_distances + 1.0
    loss = torch.where(diff > 0, diff, torch.zeros_like(diff)).mean()
    return loss
```
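For context, here's a CPU sketch of the same logic that runs fine for me (no `.cuda()`; the sizes are made up for illustration, and I squeeze the gathered match distances here so the margin subtraction is elementwise):

```python
import torch

def triplet_loss_cpu(distances, gt_indices):
    # Same masking/gather logic as above, minus .cuda()
    batch_size, n_templates = distances.shape
    template_indices = torch.arange(n_templates, dtype=torch.int64).repeat(batch_size, 1)
    non_matches = template_indices[template_indices != gt_indices.view(batch_size, 1)]
    non_matches = non_matches.view(batch_size, n_templates - 1)
    min_non_match = torch.gather(distances, 1, non_matches).min(dim=1).values
    # squeeze so match and min_non_match are both shape (batch,)
    match = torch.gather(distances, 1, gt_indices.view(batch_size, 1)).squeeze(1)
    diff = match - min_non_match + 1.0
    return torch.clamp(diff, min=0).mean()

distances = torch.rand(8, 20)            # (batch, N) synthetic L2 distances
gt_indices = torch.randint(0, 20, (8,))  # one in-range template index per anchor
loss = triplet_loss_cpu(distances, gt_indices)
print(loss)  # scalar tensor, no crash
```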
This loss seems to work for most batches, then crashes with:
```
File "/work/code/train.py", line 91, in calc_batch_hard_triplet_loss
    non_matches = non_matches.view(batch_size, n_templates - 1)
RuntimeError: shape '[18, 485]' is invalid for input of size 8731
```
This makes no sense to me. How does the template index selection step end up off by 1? A bug in torch.arange?
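To try to rule out `torch.arange`, I ran the masking step in isolation on a toy example (CPU, made-up sizes), and it produces the expected shape every time:

```python
import torch

# Toy version of the index-masking step (sizes are hypothetical)
batch_size, n_templates = 4, 6
gt_indices = torch.tensor([0, 2, 5, 3])  # all in range [0, n_templates)

template_indices = torch.arange(n_templates, dtype=torch.int64).repeat(batch_size, 1)
non_matches = template_indices[template_indices != gt_indices.view(batch_size, 1)]
# Each row drops exactly one index, so the view succeeds
non_matches = non_matches.view(batch_size, n_templates - 1)
print(non_matches.shape)  # torch.Size([4, 5])
```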
(I’m running torch under the nvcr.io/nvidia/pytorch:22.08-py3 docker container)
Thanks for any help!