Variant of CosineEmbeddingLoss

Hi,

I’m looking for a variant of CosineEmbeddingLoss that uses the squared Euclidean distance instead of the cosine distance.
Is it available in PyTorch?
Thank you :slight_smile:

PS: I tried a first implementation, but it is not vectorized (and hence not computationally efficient).

Are you looking for MSELoss?

Not exactly; I would like a version of MSELoss that behaves according to the margin and the tensor y of similarity labels, as described in CosineEmbeddingLoss (i.e. exactly the same behavior described at that link, but using the mean squared error instead of the cosine distance).
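For reference, CosineEmbeddingLoss computes (from the PyTorch docs):

loss(x1, x2, y) = 1 - cos(x1, x2)               if y == 1
                  max(0, cos(x1, x2) - margin)   if y == -1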
Maybe I can do it this way:

import torch 
import torch.nn as nn
import torch.nn.functional as F

class MarginMSELoss(nn.Module):
    def __init__(self, margin, reduction='mean'):
        super(MarginMSELoss, self).__init__()
        self.margin = margin
        self.reduction = reduction

    def forward(self, x1, x2, y):
        # x1.shape == x2.shape == (batch_size, num_features)
        # y.shape == (batch_size,)

        # per-pair MSE, i.e. the squared Euclidean distance divided by num_features
        dist = F.mse_loss(x1, x2, reduction='none').mean(dim=1)      # dist.shape == (batch_size,)
        
        # pairs labeled similar (y == 1): loss = 1 - dist
        mask = y == 1
        out1 = (1 - dist) * mask     # out1.shape == (batch_size,)

        # pairs labeled dissimilar (y != 1): loss = max(0, dist - margin)
        mask = y != 1
        out2 = torch.clamp(dist - self.margin, min=0) * mask     # out2.shape == (batch_size,)
        
        out = out1 + out2

        if self.reduction == 'none':
            return out
        elif self.reduction == 'sum':
            return out.sum()
        elif self.reduction == 'mean':
            return out.mean()
        else:
            raise ValueError(f"Reduction '{self.reduction}' is not supported")

>>> loss = MarginMSELoss(1)
>>> x1 = torch.randn(3, 5, requires_grad=True)
>>> x2 = torch.randn(3, 5, requires_grad=True)
>>> y = torch.randint(0, 2, (3,))
>>> output = loss(x1, x2, y)
>>> output.backward()

Thank you for your patience :slight_smile:

Right, sorry, I think HingeEmbeddingLoss is what you want, though its docs example uses an L1 distance rather than an L2 one.
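In case it's useful, here is a minimal sketch of that idea: compute the squared Euclidean distance per pair yourself and feed it to HingeEmbeddingLoss, which expects labels in {1, -1} and applies dist for y == 1 and max(0, margin - dist) for y == -1:

import torch
import torch.nn as nn

x1 = torch.randn(3, 5, requires_grad=True)
x2 = torch.randn(3, 5, requires_grad=True)
y = torch.tensor([1, -1, 1])                   # labels in {1, -1}, as HingeEmbeddingLoss expects

sq_dist = (x1 - x2).pow(2).sum(dim=1)          # squared Euclidean distance per pair, shape (3,)
criterion = nn.HingeEmbeddingLoss(margin=1.0)  # dist if y == 1, max(0, margin - dist) if y == -1
output = criterion(sq_dist, y)                 # scalar (reduction='mean' by default)
output.backward()

Note the sign convention is the usual one for distances: similar pairs are penalized by the distance itself, and dissimilar pairs only when they come closer than the margin.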
