I am trying to implement a contrastive loss function, but I am not sure it is correct: the loss explodes to infinity during training. Could someone point out the error in my code?
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections_1, projections_2):
        z_i = projections_1
        z_j = projections_2
        z_i_norm = F.normalize(z_i, dim=1)
        z_j_norm = F.normalize(z_j, dim=1)
        # intended as the pairwise cosine similarity between the two views
        cosine_num = torch.matmul(z_i, z_j.T)
        cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
        cosine_similarity = cosine_num / cosine_denom
        # numerator: the positive pairs sit on the diagonal
        numerator = torch.exp(torch.diag(cosine_similarity) / self.temperature)
        # denominator: zero the diagonal, then reduce over each row
        denominator = cosine_similarity
        diagonal_indices = torch.arange(denominator.size(0))
        denominator[diagonal_indices, diagonal_indices] = 0
        denominator = torch.exp(torch.sum(cosine_similarity, dim=1))
        loss = -torch.log(numerator / denominator).sum()
        return loss
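For reference, the loss I am trying to reproduce is (I believe) the NT-Xent loss from SimCLR, where each positive pair (i, j) contributes -log( exp(sim(z_i, z_j)/τ) / Σ_k exp(sim(z_i, z_k)/τ) ) and sim(·,·) is cosine similarity. Below is a minimal sketch of the simplified cross-view version I am aiming for (it only compares z_i against z_j rather than building the full 2N x 2N matrix from the paper); the function name nt_xent_sketch is just for illustration, not my actual code:

import torch
import torch.nn.functional as F

def nt_xent_sketch(projections_1, projections_2, temperature=0.5):
    # unit-normalize so the dot product is cosine similarity
    z_i = F.normalize(projections_1, dim=1)
    z_j = F.normalize(projections_2, dim=1)
    sim = torch.matmul(z_i, z_j.T) / temperature
    # positives are the matching rows, i.e. the diagonal of sim
    positives = torch.diag(sim)
    # logsumexp over each row is the log of the denominator (numerically stable)
    loss = -(positives - torch.logsumexp(sim, dim=1)).mean()
    return loss

If my class is supposed to compute something equivalent to this, I would like to understand where it diverges.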