I have a model that compares embeddings of its inputs to embeddings of a constant set of templates. I'm using a custom triplet margin loss because my model calculates the L2 distances internally. I'm getting an intermittent error (it doesn't happen on every batch) where the size of one of my tensors ends up off by 1.

```
import torch


def calc_batch_hard_triplet_loss(distances, gt_indices):
    """
    Given L2 distances between the embedded anchor in each batch and each of the N templates
    and the indices indicating the single template match, calculates the batch hard triplet loss.

    Args:
        distances: float tensor shape (batch, N) containing the L2 embedding distances
        gt_indices: int tensor shape (batch) containing the index of the matching template for each anchor
    Returns:
        The mean of the hardest triplet loss for each anchor
    """
    batch_size = distances.shape[0]
    n_templates = distances.shape[1]
    template_indices = torch.arange(n_templates, dtype=torch.int64).repeat(batch_size, 1).cuda()
    # Keep every template index except the matching one for each anchor
    non_matches = template_indices[torch.where(template_indices != gt_indices.view(batch_size, 1))]
    non_matches = non_matches.view(batch_size, n_templates - 1)  # error here after ~40 batches
    # Distance to the closest non-matching template (hardest negative), shape (batch,)
    min_non_match_distances = torch.gather(distances, dim=1, index=non_matches).min(dim=1).values
    # Distance to the matching template; squeeze to (batch,) so the subtraction below
    # stays per-anchor instead of broadcasting (batch, 1) against (batch,)
    match_distances = torch.gather(distances, dim=1, index=gt_indices.view(batch_size, 1)).squeeze(1)
    diff = match_distances - min_non_match_distances + 1.0
    loss = torch.where(diff > 0, diff, torch.zeros_like(diff)).mean()
    return loss
```
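
To make the intended behaviour concrete, here is a minimal CPU-only sketch of the same masking and gather steps on a toy input. The tensor values are made up for illustration, and the positive distances are squeezed to shape `(batch,)` so that the subtraction is done per anchor.

```python
import torch

# Toy input: 2 anchors, 4 templates (values are made up for illustration).
distances = torch.tensor([[0.5, 2.0, 1.0, 3.0],
                          [1.5, 0.2, 0.9, 0.4]])
gt_indices = torch.tensor([0, 3])

batch_size, n_templates = distances.shape
template_indices = torch.arange(n_templates, dtype=torch.int64).repeat(batch_size, 1)

# Boolean mask that is False only at each anchor's matching template.
mask = template_indices != gt_indices.view(batch_size, 1)
non_matches = template_indices[mask].view(batch_size, n_templates - 1)

# Hardest negative (closest non-match) and positive distance, both shape (batch,).
hardest_negative = torch.gather(distances, dim=1, index=non_matches).min(dim=1).values
positive = torch.gather(distances, dim=1, index=gt_indices.view(batch_size, 1)).squeeze(1)

# Hinge with margin 1.0, averaged over anchors.
loss = torch.clamp(positive - hardest_negative + 1.0, min=0).mean()
```

Here every row of `mask` drops exactly one entry, so the `view` to `(batch_size, n_templates - 1)` succeeds.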

This loss works for most batches, then crashes with:

```
File "/work/code/train.py", line 91, in calc_batch_hard_triplet_loss
non_matches = non_matches.view(batch_size, n_templates - 1)
RuntimeError: shape '[18, 485]' is invalid for input of size 8731
```

This makes no sense to me. How does the template index selection step end up with one element too many (8731 instead of 18 × 485 = 8730)? Is this a bug in torch.arange?
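
One thing that may be worth checking: 8731 = 17 × 485 + 486, which is the size you would get if a single row kept all 486 template indices. That could happen if that row's entry in `gt_indices` matches no template at all (for example -1 or a value ≥ `n_templates`). This is only a hypothesis, not a confirmed cause. A minimal CPU-only sketch of such a check, where `find_unmatched_rows` is a hypothetical helper name and the toy values are made up:

```python
import torch


def find_unmatched_rows(gt_indices: torch.Tensor, n_templates: int) -> torch.Tensor:
    """Hypothetical helper: return the batch positions whose ground-truth index
    is outside [0, n_templates), i.e. rows from which the mask removes nothing."""
    out_of_range = (gt_indices < 0) | (gt_indices >= n_templates)
    return torch.nonzero(out_of_range).flatten()


# Toy reproduction (made-up values): 3 anchors, 4 templates, one invalid index.
n_templates = 4
gt_indices = torch.tensor([2, -1, 0])
batch_size = gt_indices.shape[0]

template_indices = torch.arange(n_templates, dtype=torch.int64).repeat(batch_size, 1)
mask = template_indices != gt_indices.view(batch_size, 1)

# Each row should keep n_templates - 1 entries; a row that keeps all of them
# makes the later view(batch_size, n_templates - 1) fail by exactly one element.
kept_per_row = mask.sum(dim=1)
bad_rows = find_unmatched_rows(gt_indices, n_templates)
```

Logging `kept_per_row` and `bad_rows` for the failing batch would show whether an out-of-range label is the source of the extra element.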

(I'm running PyTorch in the nvcr.io/nvidia/pytorch:22.08-py3 Docker container.)

