Currently, I’ve implemented the following solution, inspired by the N-Pair Loss paper published at NIPS 2016:
import torch
from torch import nn
from matplotlib import pyplot as plt
import seaborn as sn

class NPairsLoss(nn.Module):
    """
    The N-Pairs Loss.
    It measures the loss given predicted tensors x1, x2, both with shape [batch_size, hidden_size],
    and a target tensor y containing the class indices 0..batch_size-1, i.e. the diagonal positions
    of the identity matrix with shape [batch_size, batch_size].
    """
    def __init__(self):
        super(NPairsLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss()

    def show(self, similarity_scores):
        # Visualize the similarity matrix as a heatmap.
        sn.heatmap(similarity_scores.detach().numpy(), annot=True, annot_kws={'size': 7}, vmin=-1.0, vmax=1.0)
        plt.show()

    def similarities(self, x1, x2):
        """
        Calculates the cosine similarity matrix for every pair (i, j),
        where i is an embedding from x1 and j is an embedding from x2.
        :param x1: a tensor with shape [batch_size, hidden_size].
        :param x2: a tensor with shape [batch_size, hidden_size].
        :return: the cosine similarity matrix with shape [batch_size, batch_size].
        """
        x1 = x1 / torch.norm(x1, p=2, dim=1, keepdim=True)
        x2 = x2 / torch.norm(x2, p=2, dim=1, keepdim=True)
        return torch.matmul(x1, x2.t())

    def forward(self, predict, target):
        """
        Computes the N-Pairs Loss between the target and the predictions.
        :param predict: the predictions of the model,
            containing the batches x1 (image embeddings) and x2 (description embeddings).
        :param target: the class indices 0..batch_size-1 (the diagonal of the identity matrix).
        :return: the N-Pairs Loss value.
        """
        x1, x2 = predict
        predict = self.similarities(x1, x2)
        self.show(predict)
        # By construction, the probability distribution should be concentrated on the diagonal
        # of the similarity matrix, so cross entropy can be used to measure the loss.
        return self.ce(predict, target)
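Since the target passed to nn.CrossEntropyLoss is just the index vector 0..batch_size-1, the loss above should reduce to the mean negative log-softmax of the diagonal entries of the similarity matrix. A minimal sketch to double-check that equivalence (the random embeddings here are only placeholders, not my real model output):

import torch
import torch.nn.functional as F

batch_size, hidden_size = 7, 768
# Placeholder embeddings, L2-normalized so the matrix product gives cosine similarities.
x1 = F.normalize(torch.rand(batch_size, hidden_size), dim=1)
x2 = F.normalize(torch.rand(batch_size, hidden_size), dim=1)
sim = torch.matmul(x1, x2.t())      # [batch_size, batch_size]
target = torch.arange(batch_size)   # the "correct" column for row i is column i

ce = F.cross_entropy(sim, target)
diag = -F.log_softmax(sim, dim=1).diagonal().mean()
print(ce.item(), diag.item())       # both values should be identical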
However, with this loss, the model ends up converging to a scenario where all dense vectors are equal to each other, which can be seen by executing the following code snippet:
batch_size = 7
hidden_size = 768

def m_model(scenario=0):
    if scenario == 0:    # all embeddings equal to each other
        p1 = torch.ones((batch_size, hidden_size))
        p2 = p1
    elif scenario == 1:  # every x1 embedding opposite to every x2 embedding
        p1 = torch.ones((batch_size, hidden_size))
        p2 = -1 * p1
    else:                # desired case
        p1 = torch.rand((batch_size, hidden_size))
        p2 = p1
    return p1, p2

predict = m_model(scenario=0)
target = torch.arange(batch_size)
loss = NPairsLoss()
print("Loss:", loss(predict, target))
# Loss: tensor(1.9459), using scenario=0
# Loss: tensor(1.9459), using scenario=1
# Loss: tensor(1.7364), using scenario=2
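These numbers line up with what the loss should produce: in scenarios 0 and 1 every entry of the similarity matrix is identical (+1 or -1), so each row's softmax is uniform and the loss is exactly log(batch_size) = log(7) ≈ 1.9459. In scenario 2 the diagonal is exactly 1, but because the cosine similarities are bounded in [-1, 1] the off-diagonal entries cannot be much smaller, so the loss only drops to about 1.74. A quick sanity check of those values (the 0.75 below is just the approximate off-diagonal cosine similarity I see for random uniform vectors in this dimension, not an exact figure):

import math
import torch
import torch.nn.functional as F

# Scenarios 0 and 1: all logits identical -> uniform softmax -> loss = log(batch_size).
print(math.log(7))  # ~1.9459

# Scenario 2 (approximately): diagonal logits of 1.0 vs. off-diagonal logits around 0.75.
logits = torch.full((7, 7), 0.75)
logits.fill_diagonal_(1.0)
print(F.cross_entropy(logits, torch.arange(7)))  # ~1.74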
Any suggestions on how to penalize these scenarios where every entry of the similarity matrix has the same value?