Custom loss function gives infinite/NaN loss

Hi, I created a custom hinge loss function for my project, where I want to embed images and captions into the same vector space. The loss is computed from a similarity matrix that holds the cosine similarity between all samples in the batch. The diagonal holds the cosine similarity of the correct pairs, and the off-diagonal entries hold the mismatched pairings. (FYI: for the hinge loss, the similarity of the correct pair must be higher than the similarity of the mismatched pairs by some margin.)
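As a minimal illustration of that layout (the toy tensors are made up for this example, and `F.cosine_similarity` with broadcasting is used instead of the manual computation below), pairing a batch with an identical copy of itself should put ones on the diagonal:

```python
import torch
import torch.nn.functional as F

# Toy batch of 3 "image" embeddings; using an identical tensor as the "captions"
# makes every correct pair the same vector, so the diagonal should be all ones.
images = torch.tensor([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])
captions = images.clone()

# Shapes (3, 1, 2) and (1, 3, 2) broadcast to (3, 3):
# entry [i, j] is cos(images[i], captions[j]).
sim = F.cosine_similarity(images.unsqueeze(1), captions.unsqueeze(0), dim=2)

print(sim.diag())  # similarities of the correct pairs
print(sim)         # off-diagonal entries are the mismatched pairs
```

For these vectors the off-diagonal entries are 0, 0.6 and 0.8, all inside [-1, 1].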

```python
import torch
import torch.nn as nn

# hinge loss using cosine similarity
def cosine_hinge_loss(embeddings_1, embeddings_2):
    # batch size
    batch_size = embeddings_1.size(0)
    # numerator: dot product of every embeddings_1 row with every embeddings_2 row, shape (B, B)
    numerator = torch.mm(embeddings_1, embeddings_2.t())
    # denominator: squared L2 norm of each embedding
    denom1 = torch.sum(torch.pow(embeddings_1, 2), dim=1)
    denom2 = torch.sum(torch.pow(embeddings_2, 2), dim=1)
    # outer product of the squared norms; the sqrt gives ||a_i|| * ||b_j||
    denominator = torch.sqrt(torch.mm(denom1.unsqueeze(1), denom2.unsqueeze(0)))
    # similarity matrix
    sim = numerator / denominator
    # get the similarity of the correct image-caption pairs (the diagonal of the similarity matrix)
    matched = sim.diag()
    # get the average mismatch of the image with incorrect captions:
    # sum the matrix along the corresponding axis, subtract the correct pair,
    # and divide by batch_size - 1 (also to correct for including the correct pair)
    mismatch_1 = (sim.sum(dim=0) - matched) / (batch_size - 1)
    # get the average mismatch of the captions with incorrect images
    mismatch_2 = (sim.sum(dim=1) - matched) / (batch_size - 1)

    return torch.sum(nn.functional.relu(mismatch_1 - matched + 1) + nn.functional.relu(mismatch_2 - matched + 1))
```
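A variant that may be worth comparing against, sketched under the assumption that normalizing the embeddings first is acceptable for the project: `F.normalize` divides each row by `max(norm, eps)`, so an all-zero embedding becomes a zero row instead of producing `0/0 = NaN`, and the matrix product of unit-length rows is directly the cosine similarity matrix. The name `cosine_hinge_loss_normalized` and the `margin` parameter are hypothetical, not from the post, and this should not be assumed to fix overflow from very large embedding values.

```python
import torch
import torch.nn.functional as F

def cosine_hinge_loss_normalized(embeddings_1, embeddings_2, margin=1.0):
    """Hypothetical variant of the hinge loss above that L2-normalizes rows first."""
    batch_size = embeddings_1.size(0)
    # Each row is divided by max(||row||, eps), so zero rows stay zero instead of NaN.
    e1 = F.normalize(embeddings_1, p=2, dim=1)
    e2 = F.normalize(embeddings_2, p=2, dim=1)
    # With unit-length rows, the matrix product is the cosine similarity matrix.
    sim = e1 @ e2.t()
    matched = sim.diag()
    # Same averaging over the B - 1 mismatched pairs as the original code.
    mismatch_1 = (sim.sum(dim=0) - matched) / (batch_size - 1)
    mismatch_2 = (sim.sum(dim=1) - matched) / (batch_size - 1)
    return torch.sum(F.relu(mismatch_1 - matched + margin) + F.relu(mismatch_2 - matched + margin))

torch.manual_seed(0)
a = torch.randn(128, 64)
b = torch.randn(128, 64)
loss = cosine_hinge_loss_normalized(a, b)
print(loss.item())
```

For random inputs like these the similarities stay close to 0, so the loss should land well below the 6 * 128 bound.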

The loss is the average similarity of each example with all the negative examples, minus the similarity of the matched pair, plus a margin of 1. This is computed in both directions: image to caption and caption to image.
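Written out, with $s_{ij}$ the cosine similarity between `embeddings_1[i]` and `embeddings_2[j]` (the entry `sim[i, j]`) and $B$ the batch size, the code computes:

$$
m^{(1)}_j = \frac{1}{B-1}\sum_{i \neq j} s_{ij}, \qquad
m^{(2)}_i = \frac{1}{B-1}\sum_{j \neq i} s_{ij},
$$

$$
\mathcal{L} = \sum_{i=1}^{B} \Big[ \max\big(0,\; m^{(1)}_i - s_{ii} + 1\big) + \max\big(0,\; m^{(2)}_i - s_{ii} + 1\big) \Big].
$$

Here $m^{(1)}$ corresponds to `mismatch_1` (column sums, `dim=0`) and $m^{(2)}$ to `mismatch_2` (row sums, `dim=1`).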

If I am correct, the loss can never exceed 6 times the batch size:

```python
return torch.sum(nn.functional.relu(mismatch_1 - matched + 1) + nn.functional.relu(mismatch_2 - matched + 1))
```

In the extreme case where the average mismatch similarity is 1 and the matched similarity is -1, each direction contributes 1 - (-1) + 1 = 3, for a total of 6 per example in the batch. Still, with a batch size of 128 I see my loss slowly go up, far exceeding what I understand to be a hard upper limit (6 * 128 = 768), and it finally leads to NaN errors.
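The bound can be checked on a tiny hand-built case. A batch of B = 2 (chosen here only for illustration) actually reaches it: each correct pair points in opposite directions (similarity -1) and each mismatched pair is identical (similarity +1). This sketch recomputes the same loss with `F.cosine_similarity` rather than the manual numerator/denominator:

```python
import torch
import torch.nn.functional as F

# Worst case for B = 2: matched pairs are opposite, mismatched pairs are identical.
e1 = torch.tensor([[1.0, 0.0], [-1.0, 0.0]])
e2 = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])
B = e1.size(0)

# sim[i, j] = cos(e1[i], e2[j]); here -1 on the diagonal, +1 off it.
sim = F.cosine_similarity(e1.unsqueeze(1), e2.unsqueeze(0), dim=2)
matched = sim.diag()
mismatch_1 = (sim.sum(dim=0) - matched) / (B - 1)
mismatch_2 = (sim.sum(dim=1) - matched) / (B - 1)
loss = torch.sum(F.relu(mismatch_1 - matched + 1) + F.relu(mismatch_2 - matched + 1))

print(loss.item())  # 12.0, i.e. 6 * B
```

So with similarities confined to [-1, 1], a total above 6 * B can only appear if some entry of `sim` leaves that range or becomes inf/NaN.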

Does anyone have an idea how the loss can exceed this limit? I could understand getting errors immediately if there were an inf or a 0 somewhere in the computation, but a loss of more than 6 times the batch size means that my cosine similarities fall outside [-1, 1].
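One mechanism worth ruling out (an assumption, not something the post confirms): if the embedding norms grow during training, the float32 squares in the numerator and denominator can overflow (float32 tops out around 3.4e38), giving `inf / inf = NaN`; an all-zero embedding gives `0 / 0 = NaN`. The sketch reproduces both with the post's arithmetic. `similarity_matrix` and `check_similarity` are hypothetical helpers written for this illustration.

```python
import torch

def similarity_matrix(embeddings_1, embeddings_2):
    # Same arithmetic as the post: dot products divided by the product of norms.
    numerator = embeddings_1 @ embeddings_2.t()
    sq1 = embeddings_1.pow(2).sum(dim=1)
    sq2 = embeddings_2.pow(2).sum(dim=1)
    denominator = torch.sqrt(sq1.unsqueeze(1) * sq2.unsqueeze(0))
    return numerator / denominator

def check_similarity(sim):
    # Hypothetical helper: count non-finite entries and report the largest |finite value|.
    finite_mask = torch.isfinite(sim)
    n_nonfinite = int((~finite_mask).sum().item())
    finite = sim[finite_mask]
    max_abs = finite.abs().max().item() if finite.numel() > 0 else float("nan")
    return n_nonfinite, max_abs

small = torch.tensor([[3.0, 4.0]])
large = small * 1e20        # squared entries ~1e41 overflow float32 to inf
zero = torch.zeros(1, 2)    # zero norm

print(check_similarity(similarity_matrix(small, small)))  # no non-finite entries
print(check_similarity(similarity_matrix(large, large)))  # inf / inf -> NaN
print(check_similarity(similarity_matrix(zero, small)))   # 0 / 0 -> NaN
```

Between these extremes, the outer product of squared norms can overflow first (for norms around 1e10), which silently turns similarities into 0 rather than NaN. Neither case by itself produces a loss above the bound, so running a check like this on `sim` at each step, together with `torch.autograd.set_detect_anomaly(True)` for the backward pass, may help locate where the values first go wrong.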

Regards, and thanks for your input.