Siamese Network for gender recognition

Hi everyone, recently I started to read about Siamese Nets and I wanted to try this type of model on a gender recognition task. My data is a .csv dataset containing ~3000 n-vectors of audio features (n=20). The loss function I’m using is the Contrastive Loss. Here is my model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(torch.nn.Module):
  """
  Contrastive loss function.
  Based on:
  """
  def __init__(self, margin=1.0):
      super(ContrastiveLoss, self).__init__()
      self.margin = margin

  def forward(self, output1, output2, label):
      euclidean_distance = F.pairwise_distance(output1, output2)
      loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance.double(), 2) +
                                    (label) * torch.pow(torch.clamp(self.margin - euclidean_distance.double(), min=0.0), 2))

      return loss_contrastive

class SiameseMLP(nn.Module):
  def __init__(self):
    super(SiameseMLP, self).__init__()
    self.layers = nn.Sequential(
        nn.Linear(20, 256),
        nn.Linear(256, 256),
        nn.Linear(256, 256),
        nn.Linear(256, 2)
    )

  def forward_once(self, x):
    x = x.view(-1, 20)
    x = self.layers(x)
    return x

  def forward(self, x_1, x_2):
    y_1 = self.forward_once(x_1)
    y_2 = self.forward_once(x_2)
    return y_1, y_2

I ran some tests, but so far I can’t get a good training loss: it stays around ~0.33. When I compute the dissimilarity between pairs, I get seemingly random values, so I think the model is not really learning. Do you have any suggestions? Thanks.

I am not sure of the exact reason.

Below are some points that I could think of:

  1. Double-check the distance term. F.pairwise_distance returns the plain (unsquared) Euclidean distance by default, so squaring it, as you do, matches the usual contrastive loss.
  2. Check whether your dataset is balanced or imbalanced.

  3. I see that you have used only 2 dimensions in the final layer. Try using more (64, 128, 256, etc.).
    Reducing 20 dimensions to 2 may not generalize well, though I am not sure.

  4. Try normalizing the final descriptor (F.normalize()).

  5. Play around with the margin parameter of the Siamese (contrastive) loss.
    Also, try triplet loss.
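
On the balance point, here is a minimal sketch of how you could check the class counts and sample pairs so that similar and dissimilar pairs are equally frequent. The helper name make_balanced_pairs and the example labels are made up for illustration, and I am assuming the pair label convention that your ContrastiveLoss implies (0 for a same-class pair, 1 for a different-class pair). It uses only the standard library, so you would still need to turn the index pairs into tensors yourself.

```python
import random
from collections import Counter


def make_balanced_pairs(labels, n_pairs, seed=0):
    """Sample (i, j, pair_label) index pairs, alternating pair types.

    pair_label is 0 for a same-class pair and 1 for a different-class
    pair (assumed to match the ContrastiveLoss in the question).
    """
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)

    classes = sorted(by_class)
    # Only classes with at least two samples can form a same-class pair.
    pairable = [c for c in classes if len(by_class[c]) >= 2]
    if len(classes) < 2 or not pairable:
        raise ValueError("need at least two classes and one class with two samples")

    pairs = []
    for k in range(n_pairs):
        if k % 2 == 0:
            # Same-class pair: two distinct samples from one class.
            c = rng.choice(pairable)
            i, j = rng.sample(by_class[c], 2)
            pairs.append((i, j, 0))
        else:
            # Different-class pair: one sample from each of two classes.
            c1, c2 = rng.sample(classes, 2)
            pairs.append((rng.choice(by_class[c1]), rng.choice(by_class[c2]), 1))
    return pairs


# Hypothetical, deliberately imbalanced gender labels (illustration only).
labels = [0] * 70 + [1] * 30
print("class counts:", Counter(labels))

pairs = make_balanced_pairs(labels, n_pairs=1000)
print("pair label counts:", Counter(t for _, _, t in pairs))
```

With an even n_pairs, this gives exactly half similar and half dissimilar pairs, regardless of how skewed the class counts are. Classes are picked uniformly for each pair, so a minority class is effectively oversampled. That may or may not be what you want.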