Custom Rank Loss Function

I’m trying to define a custom loss function in PyTorch. Currently, it is as follows:

import numpy as np
import torch
from torch import nn
from scipy.spatial.distance import cdist

class MRRLoss(nn.Module):
    """ Mean Reciprocal Rank Loss """
    def __init__(self):
        super(MRRLoss, self).__init__()

    def forward(self, u, v, distance_metric="cosine"):
        # cosine distance between all pairs of embeddings in the u and v batches
        distances = cdist(u.detach().numpy(), v.detach().numpy(), metric=distance_metric)
        # by construction the diagonal contains the correct elements
        correct_elements = np.expand_dims(np.diag(distances), axis=-1)
        # number of elements ranked wrong
        return np.sum(distances < correct_elements)

This loss function is used to train a model that generates embeddings for different objects, such as images and texts. The objective is that the embedding of image i is as close as possible to the embedding of the text t that describes it.

The loss takes as input the batches u and v, representing image embeddings and text embeddings, respectively. It calculates the cosine distance between all pairs of u and v embeddings and returns a value indicating how many images were ranked incorrectly.
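
To make the intended counting concrete, here is a tiny made-up example of what the forward pass is supposed to compute for a batch of 3 image/text pairs:

import numpy as np

# made-up 3x3 distance matrix: rows = images, columns = texts
distances = np.array([[0.1, 0.4, 0.3],
                      [0.5, 0.2, 0.1],
                      [0.6, 0.4, 0.3]])
# the diagonal holds the distance of each image to its own text, shape (3, 1)
correct_elements = np.expand_dims(np.diag(distances), axis=-1)
# broadcasting compares every entry of row i against distances[i, i];
# the sum counts how many texts are closer to an image than its true text
print(np.sum(distances < correct_elements))   # -> 1 (only text 2 is closer to image 1 than its own text)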

However, during training I am having the following error:

AttributeError: 'numpy.int64' object has no attribute 'backward'
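
As far as I can tell, this happens because .detach().numpy() leaves the autograd graph, so np.sum(...) returns a plain NumPy scalar with no gradient history. A minimal reproduction of the error (the shapes are made up):

import numpy as np
import torch
from scipy.spatial.distance import cdist

u = torch.randn(4, 128, requires_grad=True)
v = torch.randn(4, 128, requires_grad=True)

distances = cdist(u.detach().numpy(), v.detach().numpy(), metric="cosine")
loss = np.sum(distances < np.expand_dims(np.diag(distances), axis=-1))
print(type(loss))   # <class 'numpy.int64'>, not a torch.Tensor
loss.backward()     # AttributeError: 'numpy.int64' object has no attribute 'backward'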

How could I get around this kind of error?

Can you try this?

class MRRLoss(nn.Module):
    """ Mean Reciprocal Rank Loss """
    def __init__(self):
        super(MRRLoss, self).__init__()

    def forward(self, u, v):
        u = torch.reshape(u, (-1,))
        v = torch.reshape(v, (-1,))
        # cosine distance between all pairs of embeddings in the u and v batches
        cos = nn.CosineSimilarity(dim=0)
        distances = cos(u, v)
        # by construction the diagonal contains the correct elements
        correct_elements = torch.diag(cos(u, v), 0).unsqueeze(-1)
        # number of elements ranked wrong
        return torch.sum(distances < correct_elements)

From what I could understand, nn.CosineSimilarity computes the cosine similarity between an element i of batch u and the corresponding element i of batch v. What I’m looking for is an approach to compute the similarity matrix of all elements of u to all elements of v and define it as a PyTorch loss function. It was so easy in TensorFlow 2.
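
Concretely, something like the pure-torch sketch below is what I mean by a similarity matrix; the function name is just illustrative. (I realize the < comparison in my original loss is not differentiable either, but at least this keeps the similarity matrix on the autograd graph.)

import torch
import torch.nn.functional as F

def pairwise_cosine_similarity(u, v):
    # u: [batch_size, hidden_size], v: [batch_size, hidden_size]
    u = F.normalize(u, p=2, dim=1)   # L2-normalize each row
    v = F.normalize(v, p=2, dim=1)
    return torch.matmul(u, v.t())    # [batch_size, batch_size], stays on the autograd graph

u = torch.randn(4, 128, requires_grad=True)
v = torch.randn(4, 128, requires_grad=True)
print(pairwise_cosine_similarity(u, v).shape)   # torch.Size([4, 4])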

I guess if you want to compute it over all dims, you can simply reshape it with u = torch.reshape(u, (-1,)) and then calculate it.

Thanks but now the error is:

RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

Yes, it looks good.
I have updated the answer again with “torch.diagonal(cos(u,v),0).unsqueeze(0)”.
Can you please try that? It should work now.
Try not to convert to NumPy arrays as long as possible; try to perform operations on torch tensors only.

<ipython-input-41-59f0a1e15b43> in forward(self, u, v)
     13         distances=cos(u,v)
     14         # by construction the diagonal contains the correct elements
---> 15         correct_elements = torch.diagonal(cos(u,v),0).unsqueeze(0)
     16         # number of elements ranked wrong.
     17         return torch.sum(distances < correct_elements)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
Yes, sorry, I thought you were extracting the diagonal elements.

torch.diag(cos(u,v), 0).unsqueeze(-1)

This should work.

Thank you very much, but it’s not working.
You can verify it in a realistic example on Google Colab:
Custom Rank Loss Function

In case it helps someone, I’ve implemented a better solution to this question, inspired by the N-Pairs Loss published at NIPS 2016:

import torch
from torch import nn


class NPairsLoss(nn.Module):
    """
    The N-Pairs Loss.
    It measures the loss given predicted tensors x1, x2, both with shape [batch_size, hidden_size],
    and a target tensor y which is the identity matrix with shape [batch_size, batch_size].
    """

    def __init__(self):
        super(NPairsLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss()

    def similarities(self, x1, x2):
        """
        Calculates the cosine similarity matrix for every pair (i, j),
        where i is an embedding from x1 and j is another embedding from x2.

        :param x1: a tensor with shape [batch_size, hidden_size].
        :param x2: a tensor with shape [batch_size, hidden_size].
        :return: the cosine similarity matrix with shape [batch_size, batch_size].
        """
        x1 = x1 / torch.norm(x1, p=2, dim=1, keepdim=True)
        x2 = x2 / torch.norm(x2, p=2, dim=1, keepdim=True)
        return torch.matmul(x1, x2.t())

    def forward(self, predict, target):
        """
        Computes the N-Pairs Loss between the target and predictions.
        :param predict: the predictions of the model,
        containing the batches x1 (image embeddings) and x2 (description embeddings).
        :param target: the identity matrix with shape [batch_size, batch_size].
        :return: N-Pairs Loss value.
        """
        x1, x2 = predict
        predict = self.similarities(x1, x2)
        # by construction, the probability distribution must be concentrated
        # on the diagonal of the similarities matrix,
        # so cross-entropy can be used to measure the loss.
        return self.ce(predict, target)
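
A quick usage sketch, with made-up sizes. Note that passing the identity matrix as a probability target requires a reasonably recent PyTorch version; with older versions, class indices via torch.arange(batch_size) are equivalent:

import torch

batch_size, hidden_size = 4, 128
image_emb = torch.randn(batch_size, hidden_size, requires_grad=True)
text_emb = torch.randn(batch_size, hidden_size, requires_grad=True)

criterion = NPairsLoss()
target = torch.eye(batch_size)          # the i-th image should match the i-th text
loss = criterion((image_emb, text_emb), target)
loss.backward()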

A more recent option is fast-soft-sort from Google Research, and it supports PyTorch.
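
For reference, a minimal sketch of how it might be used; the import path and the soft_rank signature are my reading of the google-research/fast-soft-sort README, so double-check them against the repo:

import torch
from fast_soft_sort.pytorch_ops import soft_rank   # assumed import path

values = torch.tensor([[5.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True)
ranks = soft_rank(values, regularization_strength=1.0)   # differentiable (soft) ranks, shape [1, 3]
ranks.sum().backward()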