Face recognition model loss not decreasing

I wrote a script to do train a Siamese Network style model for face recognition on LFW dataset but the training loss doesnt decrease at all. Probably there’s a bug in my implementation. Could you please point it out.
Right now my code does:

  • Each epoch has 0.5M triplets all generated in an online way from data (since the exhaustive number of triplets is too big).
  • Triplet sampling method: We have a dictionary of {class_id: list of file paths with that class id}. We then create a list of classes which we can use for positive class (some classes have 1 image). At any iteration we randomly sample a positive class from this refined list and a negative class from the original list. We randomly sample 2 images from positive (as Anchor or A and Positive as P) and 1 from negative (Negative or N). A,P,N form our triplet.
  • Model used is ResNet with the ultimate (512,1000) softmax layer is replaced with (512,128) Dense layer (no activation). To avoid overfitting, only the last Dense and layer4 are kept trainable and rest are frozen.
  • During training we find triplets which are semi-hard in a batch (Loss between 0 and margin) and use only those to do backprop (they mention this in the FaceNet paper)
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import os, glob
import numpy as np
from PIL import Image

image_size = 224
batch_size = 512
margin = 0.5
learning_rate = 1e-3
num_epochs = 1000

model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 128, bias=False)

for param in model.parameters():
  param.requires_grad = False
for param in model.fc.parameters():
  param.requires_grad = True
for param in model.layer4.parameters():
  param.requires_grad = True

optimizer = optim.Adam(params=list(model.fc.parameters())+list(model.layer4.parameters()), lr=learning_rate, weight_decay=0.05)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model).to(device)
writer = SummaryWriter(log_dir="logs/")

class TripletDataset(Dataset):
  def __init__(self, rootdir, transform):
    self.rootdir = rootdir
    self.classes = os.listdir(self.rootdir)
    self.file_paths = {c: glob.glob(os.path.join(rootdir, c, "*.jpg")) for c in self.classes}
    self.positive_classes = [c for c in self.classes if len(self.file_paths[c])>=2]
    self.transform = transform

  def __getitem__(self, index=None):
    class_pos, class_neg = None, None
    while class_pos == class_neg:
      class_pos = np.random.choice(a=self.positive_classes, size=1)[0]
      class_neg = np.random.choice(a=self.classes, size=1)[0]

    fp_a, fp_p = np.random.choice(a=self.file_paths[class_pos], size=2, replace=False)
    fp_n = np.random.choice(a=self.file_paths[class_neg], size=1)[0]

    return {
        "fp_a": fp_a,
        "fp_p": fp_p,
        "fp_n": fp_n,
        "A": self.transform(Image.open(fp_a)),
        "P": self.transform(Image.open(fp_p)),
        "N": self.transform(Image.open(fp_n)),

  def __len__(self):
    return 500000

def triplet_loss(a, p, n, margin=margin):
    d_ap = (a-p).norm(p='fro', dim=1)
    d_an = (a-n).norm(p='fro', dim=1)
    loss = torch.clamp(d_ap-d_an+margin, min=0)
    return loss, d_ap.mean(), d_an.mean()

transform = transforms.Compose([
        transforms.Normalize([0.596, 0.436, 0.586], [0.2066, 0.240, 0.186])
train_dataset = TripletDataset("lfw", transform)
nw = 4 if torch.cuda.is_available() else 0
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, shuffle=True)

num_batches = len(train_dataloader)
running_loss = 0

for epoch in range(num_epochs):
    for batch_id, dictionary in enumerate(train_dataloader):
        a, p, n = dictionary["A"], dictionary["P"], dictionary["N"]
        a, p, n = a.to(device), p.to(device), n.to(device)
        emb_a, emb_p, emb_n = model(a), model(p), model(n)
        losses, d_ap, d_an = triplet_loss(a=emb_a, p=emb_p, n=emb_n)
        semi_hard_triplets = torch.where((losses>0) & (losses<margin))
        losses = losses[semi_hard_triplets]
        loss = losses.mean()
        running_loss += loss.item()
        print("Epoch {} Batch {}/{} Loss = {} Avg AP dist = {} Avg AN dist = {}".format(epoch, batch_id, num_batches, loss.item(), d_ap.item(), d_an.item()), flush=True)
        writer.add_scalar("Loss/Train", loss.item(), epoch*num_batches+batch_id)
        writer.add_scalars("AP_AN_Distances", {"AP": d_ap.item(), "AN": d_an.item()}, epoch*num_batches+batch_id)
    print("Epoch {} Avg Loss {}".format(epoch, running_loss/num_batches), flush=True)
    writer.add_scalar("Epoch_Loss", running_loss/num_batches, epoch)
    torch.save(model.state_dict(), "facenet_epoch_{}.pth".format(epoch))

Loss graphs: https://tensorboard.dev/experiment/8TgzPTjuRCOFkFV5lr5etQ/
Please let me know if you need some other information to help you help me.

Any help will be appreciated

Im facing the same problem, did you figure it out ?