Contrastive sentence embeddings loss implementation

I’m looking to implement the supervised SimCSE loss, which uses both positive and hard-negative pairs:

$$\ell_i = -\log \frac{e^{\mathrm{sim}(h_i,\, h_i^{+})/\tau}}{\sum_{j=1}^{N}\left(e^{\mathrm{sim}(h_i,\, h_j^{+})/\tau} + e^{\mathrm{sim}(h_i,\, h_j^{-})/\tau}\right)}$$

where $\mathrm{sim}$ is cosine similarity, $\tau$ is the temperature, and $h_i$, $h_i^{+}$, $h_i^{-}$ are the anchor, positive, and hard-negative embeddings.

Is there a way to vectorize the naive implementation below?

import torch
import torch.nn.functional as F


batch_size = 4
feature_dim = 1024

# h[:, 0] = anchor h_i, h[:, 1] = positive h_i+, h[:, 2] = hard negative h_i-
h = torch.randn(batch_size, 3, feature_dim)

temp = 1.0  # temperature tau

# numerator: exp(sim(h_i, h_i+) / tau) for each example
num = torch.exp(
    F.cosine_similarity(h[:, 0, :], h[:, 1, :], dim=1) / temp
)

# denominator: for each anchor, sum exp-similarities to every
# positive and every hard negative in the batch
denom = torch.empty_like(num)
for j in range(batch_size):
    denomjj = 0
    for jj in range(batch_size):
        denomjj += torch.exp(F.cosine_similarity(h[j, 0, :], h[jj, 1, :], dim=0) / temp)
        denomjj += torch.exp(F.cosine_similarity(h[j, 0, :], h[jj, 2, :], dim=0) / temp)
    denom[j] = denomjj

loss = -torch.log(num / denom)  # per-example loss, shape (batch_size,)
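One way to vectorize this (a sketch, not necessarily the canonical implementation; the function name `supervised_simcse_loss` is my own): L2-normalize the embeddings so a dot product equals cosine similarity, build the full `(B, 2B)` similarity matrix with one matmul, and let `F.cross_entropy` do the log-softmax. Because the positive for row `i` sits in column `i`, cross-entropy with labels `arange(B)` reproduces `-log(num / denom)` exactly.

```python
import torch
import torch.nn.functional as F


def supervised_simcse_loss(h, temp=1.0):
    """Vectorized supervised SimCSE loss (sketch).

    h: (B, 3, D) tensor of anchor / positive / hard-negative embeddings.
    Returns the per-example loss, shape (B,).
    """
    b = h.size(0)
    # normalize so that anchor @ cands.T gives cosine similarities
    anchor = F.normalize(h[:, 0, :], dim=1)  # (B, D)
    pos = F.normalize(h[:, 1, :], dim=1)     # (B, D)
    neg = F.normalize(h[:, 2, :], dim=1)     # (B, D)
    cands = torch.cat([pos, neg], dim=0)     # (2B, D): positives, then negatives
    logits = anchor @ cands.t() / temp       # (B, 2B) similarity matrix
    # the positive for anchor i is column i, so cross-entropy with
    # labels = arange(B) computes -log(exp(logits[i, i]) / sum_j exp(logits[i, j]))
    labels = torch.arange(b, device=h.device)
    return F.cross_entropy(logits, labels, reduction="none")
```

This should match the double loop up to floating-point differences; reduce with `.mean()` to get a scalar training loss.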