Is it a good idea to use cross entropy to implement InfoNCE?

Sorry if my question is too obscure. To give a short description of what InfoNCE is: it's the loss function used in the SimCLR paper, which does contrastive training of a joint embedding model. The model is a network that outputs an embedding for each input sample. If the model is trained correctly, the embeddings of similar inputs will be close to each other, while the embeddings of dissimilar inputs will be farther apart (usually measured by cosine similarity). The job of InfoNCE is to implement this push and pull.

The important part of contrastive training is that each batch comes in pairs. The two samples in a pair are similar to each other, while samples from different pairs are considered dissimilar. Each sample in a pair is fed to the network, which produces an embedding for it. As part of the loss function, we normalize each output and then compute the matrix product of the two sets of embeddings in the batch. This gives us a square matrix whose size is the number of pairs in the batch, where the cell at row i and column j holds the cosine similarity between the first item of pair i and the second item of pair j. Of course, if the network were trained perfectly, this matrix would be (close to) the identity matrix.
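For concreteness, this is the kind of matrix I mean. The numbers below are just random placeholders, not real embeddings; the point is the shape and the identity target:

import torch

# A toy batch of 3 pairs with 4-dimensional embeddings (random placeholder values)
emb_a = torch.nn.functional.normalize(torch.randn(3, 4), dim=-1)  # first items of the pairs
emb_b = torch.nn.functional.normalize(torch.randn(3, 4), dim=-1)  # second items of the pairs

# 3x3 matrix: entry (i, j) is the cosine similarity between emb_a[i] and emb_b[j]
sim = emb_a @ emb_b.T
print(sim.shape)     # torch.Size([3, 3])
print(torch.eye(3))  # the target that training should push `sim` toward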

Now, my question is: if I have the cosine similarity matrix for the batch and I want to compute a loss value that pushes this matrix toward the identity matrix, do you think the following implementation should work?

import torch

def info_nce_loss(out1, out2):
    # Normalize the output vectors so their length is one
    out1 = torch.nn.functional.normalize(out1, p=2, dim=-1)
    out2 = torch.nn.functional.normalize(out2, p=2, dim=-1)
    # Cosine similarity scores: entry (i, j) compares pair i's first item with pair j's second item
    scores = torch.mm(out1, out2.T)
    # The target for row i is column i, i.e. the diagonal of the similarity matrix
    labels = torch.arange(scores.size(0), device=scores.device)
    return torch.nn.functional.cross_entropy(scores, labels)
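For context, this is roughly how I call it in my training loop (encoder, x1, and x2 here are placeholders for my actual model and the two augmented views of a batch):

z1 = encoder(x1)  # embeddings of the first items of the pairs, shape (N, D)
z2 = encoder(x2)  # embeddings of the second items of the pairs, shape (N, D)
loss = info_nce_loss(z1, z2)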

I’ve been using this implementation in my experiments. In some scenarios it works and in some it does not, which makes me wonder whether I should use a different implementation of InfoNCE.
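One difference from the loss in the SimCLR paper that I'm aware of: their NT-Xent loss divides the similarities by a temperature and is computed in both directions (it also uses same-view samples as extra negatives, which I'm ignoring here). A sketch of that simpler symmetric variant, with 0.1 as an assumed temperature value, would look like this:

def info_nce_loss_v2(out1, out2, temperature=0.1):
    # Normalize embeddings to unit length so dot products are cosine similarities
    out1 = torch.nn.functional.normalize(out1, p=2, dim=-1)
    out2 = torch.nn.functional.normalize(out2, p=2, dim=-1)
    # Cosine similarity matrix scaled by the temperature
    scores = torch.mm(out1, out2.T) / temperature
    labels = torch.arange(scores.size(0), device=scores.device)
    # Cross entropy in both directions: match view 1 to view 2 and view 2 to view 1
    loss_12 = torch.nn.functional.cross_entropy(scores, labels)
    loss_21 = torch.nn.functional.cross_entropy(scores.T, labels)
    return 0.5 * (loss_12 + loss_21)

Whether the missing temperature or the one-directional loss is what causes my inconsistent results is part of what I'm asking.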

Any suggestion is very much appreciated.

P.S. If anyone’s interested, this is the SimCLR paper