Metric Learning For Classification - How good is it?

Hello guys!

I am trying to figure out whether, and how, I could use metric learning for image classification. The approach I am trying is to cluster objects of the same class together while maximizing the distance between clusters of different classes.

To give a concrete example, I am trying to classify whether a dress is a summer dress or a winter dress. The former is free, flowy and loose-fitting, while the latter is body-hugging, made of thick fabric and covers the whole body; the classes are kinda self-explanatory that way…

For the metric learning model, I have used a ResNet50 backbone followed by average pooling and a linear layer that reduces the embedding dimension from 2048 to 512.

When loading the data, I sample both classes into every batch: with a batch size of 32, each batch contains 16 images per class, so there are plenty of positive and negative pairs to compare.
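One way to build such class-balanced batches is sketched below in plain Python. The function name `balanced_batches` and its parameters are hypothetical, and leftover samples that do not fill a complete batch are simply dropped, which is a simplification.

```python
import random
from collections import defaultdict


def balanced_batches(labels, per_class=16, seed=0):
    """Return lists of dataset indices with `per_class` samples of every class."""
    rng = random.Random(seed)

    # Group dataset indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    for indices in by_class.values():
        rng.shuffle(indices)

    # The smallest class limits how many full batches can be formed.
    n_batches = min(len(v) for v in by_class.values()) // per_class

    batches = []
    for b in range(n_batches):
        batch = []
        for indices in by_class.values():
            batch.extend(indices[b * per_class:(b + 1) * per_class])
        rng.shuffle(batch)  # mix the classes within the batch
        batches.append(batch)
    return batches


# Toy example: 40 "summer" and 50 "winter" images.
labels = ["summer"] * 40 + ["winter"] * 50
batches = balanced_batches(labels, per_class=16)
```

The resulting list of index lists can be passed to a PyTorch `DataLoader` through its `batch_sampler` argument, so every batch of 32 holds 16 images of each class.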

I have used Multi-Similarity (MS) Loss with hyper-parameters alpha = 2 (weight on positive pairs), beta = 50 (weight on negative pairs) and base = 0.5, as mentioned in the MS Loss paper. Any suggestions on tuning these are also welcome.
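To show what the three hyper-parameters do, here is a forward-only NumPy sketch of the MS Loss formula. It is a simplification: the pair-mining step from the paper is left out, the function name `ms_loss` is hypothetical, and a real training loop would need a differentiable implementation (for example in PyTorch).

```python
import numpy as np


def ms_loss(embeddings, labels, alpha=2.0, beta=50.0, base=0.5):
    """Multi-Similarity loss over all pairs in a batch (no pair mining)."""
    emb = np.asarray(embeddings, dtype=float)
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sim = emb @ emb.T  # cosine similarity matrix, shape (B, B)

    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    pos_mask = same & ~np.eye(len(labels), dtype=bool)  # same class, not self
    neg_mask = ~same                                    # different class

    # Positives are penalised when their similarity falls below `base`,
    # negatives when their similarity rises above it.
    pos_exp = np.where(pos_mask, np.exp(-alpha * (sim - base)), 0.0)
    neg_exp = np.where(neg_mask, np.exp(beta * (sim - base)), 0.0)

    pos_term = np.log1p(pos_exp.sum(axis=1)) / alpha
    neg_term = np.log1p(neg_exp.sum(axis=1)) / beta
    return float(np.mean(pos_term + neg_term))


# Toy check: two well-separated classes vs. two fully overlapping classes.
toy_labels = [0, 0, 1, 1]
separated = [[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [-1.0, 0.0]]
collapsed = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
loss_separated = ms_loss(separated, toy_labels)
loss_collapsed = ms_loss(collapsed, toy_labels)
```

With beta = 50 the negative term reacts very sharply to any negative pair whose similarity exceeds `base`, while alpha = 2 makes the pull on positive pairs comparatively gentle, so `loss_collapsed` comes out larger than `loss_separated`.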

After training for around 20 epochs, using differential learning rates and re-tuning the learning rate roughly every 5 epochs, the model gives the following embedding plot (512-dimensional embeddings):

However, when I do the same task with a standard ResNet50 classifier trained with cross-entropy loss, I get the following embedding plot:

What I don’t understand is why MS Loss, which is specifically designed to pull embeddings of the same class closer together and push embeddings of different classes apart, fails to cluster the images well, whereas a simple cross-entropy classification loss performs better.

Could someone from the community shed light on this? Any insights are appreciated.

Thanks & Regards,