Triplet loss stuck at margin alpha value

Hi everyone

I’m struggling to get a triplet loss to converge. I’m trying to do face verification (a 1:1 problem) with as little computation as possible, since I don’t have a GPU.
So I’m using the facenet-pytorch model InceptionResnetV1 pretrained on vggface2 (casia-webface gives the same results).
I created a dataset of anchor, positive and negative samples and unfroze the last linear layer (~900k parameters).
Training starts well, but then the loss gets completely stuck… After some investigation, it seems the loss is stuck at the value alpha (the margin of the PyTorch triplet loss)…
If we look at the loss equation, it says

max[ L2norm(f(A)-f(P)) - L2norm(f(A)-f(N)) + alpha, 0 ]

So it seems the condition below always holds, which is weird…

  • L2norm(f(A)-f(P)) = L2norm(f(A)-f(N))

I’ve tried many combinations (changing the lr, unfreezing more layers…) but the loss still behaves the same way…
I’m working on Google Colab with around 8000 images (4k positives/anchors and 4k negatives).

Could someone help me with this, please?

class SiameseDataset2(Dataset):
    def __init__(self, list_PIL_positive, list_PIL_negative, val_stride = 0, isValSet_bool = None, Transform=False, Normalize=False, mean=None, std=None):
        
        self.Transform = Transform
        self.Normalize = Normalize
        self.mean = mean
        self.std = std

        self.A = random.sample(list_PIL_positive, len(list_PIL_positive))
        self.P = random.sample(list_PIL_positive, len(list_PIL_positive))
        self.N = random.sample(list_PIL_negative, len(list_PIL_negative))
        self.PN = self.P + self.N

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.A = self.A[::val_stride]
            self.P = self.P[::val_stride]
            self.N = self.N[::val_stride]
    
        elif val_stride > 0:
            del self.A[::val_stride]
            del self.P[::val_stride]
            del self.N[::val_stride]
    
    def preprocess(self, img_PIL):
        transform = torchvision.transforms.Compose([torchvision.transforms.Resize((160,160)),
                                                    torchvision.transforms.ToTensor()])
        if self.Normalize:
            transform = torchvision.transforms.Compose([torchvision.transforms.Resize((160,160)),
                                                        torchvision.transforms.ToTensor(),
                                                        torchvision.transforms.Normalize(mean=self.mean, std=self.std)])
        img = transform(img_PIL)
        return img


    def __len__(self):
        return len(self.P)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        PIL_imageA = self.A[idx]
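        # Draw the positive from the positive pool and the negative from the
        # negative pool. Sampling both from self.PN makes d(A,P) and d(A,N)
        # identically distributed, so on average the loss sits at the margin.
        PIL_imageP = random.choice(self.P)
        PIL_imageN = random.choice(self.N)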

        if self.Transform:
            imageA = self.preprocess(PIL_imageA)
            imageP = self.preprocess(PIL_imageP)
            imageN = self.preprocess(PIL_imageN)

            return imageA, imageP, imageN
        else :
            return PIL_imageA, PIL_imageP, PIL_imageN

BATCH_SIZE = 32
siameseDataset_train = SiameseDataset2(PIL_imgs['PIL_positif'], PIL_imgs['PIL_negatif'], val_stride=10, isValSet_bool=False, Transform=True, Normalize=False)
siameseDataset_test = SiameseDataset2(PIL_imgs['PIL_positif'], PIL_imgs['PIL_negatif'], val_stride=10, isValSet_bool=True, Transform=True, Normalize=False)
train_dataloader = DataLoader(siameseDataset_train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(siameseDataset_test, batch_size=BATCH_SIZE, shuffle=True)
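device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
num_epochs = 3
model = InceptionResnetV1(pretrained='vggface2')
# Freeze everything except the last linear layer
for param in model.parameters():
    param.requires_grad = False
model.last_linear.weight.requires_grad = True
# Build the optimizer only after the final model exists; otherwise it holds the
# parameters of the first (discarded) model instance and never updates this one
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)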
def model_loop(model, epochs, trainloader, validloader, batch_size, anchor_img_, optimizer, triplet_loss, device):
    model.to(device)
    train_loss_list = []
    valid_loss_list = []
    size_train = len(trainloader.dataset)
    size_test = len(validloader.dataset)
    last_batch_size_train = size_train % batch_size
    last_batch_size_test = size_test % batch_size

    for epoch in range(epochs):  # use the function argument, not the global num_epochs
        print(f"Epoch {epoch+1} on {device} \n-------------------------------")
        
        train_loss = 0.0
        model.train()

        for batch, (anch, pos, neg) in enumerate(trainloader):
            # Transfer Data to GPU if available
            anch, pos, neg = anch.to(device), pos.to(device), neg.to(device)

            # Clear the gradients
            optimizer.zero_grad()

            # Make prediction & compute the mini-batch training loss            
            anch_embedding = model(anch)
            pos_embedding = model(pos)
            neg_embedding = model(neg)

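            # L2-normalize each embedding separately (dim=1); torch.norm without
            # a dim would divide by the norm of the whole batch instead
            anch_embedding = anch_embedding / anch_embedding.norm(dim=1, keepdim=True)
            pos_embedding = pos_embedding / pos_embedding.norm(dim=1, keepdim=True)
            neg_embedding = neg_embedding / neg_embedding.norm(dim=1, keepdim=True)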

            loss = triplet_loss(anch_embedding, pos_embedding, neg_embedding)

            # Compute the gradients
            loss.backward()

            # Update Weights
            optimizer.step()

            # Aggregate mini-batch training losses
            train_loss += loss.item()
            train_loss_list.append(train_loss)

            
            if batch == 0 or batch%10 == 0:
                loss, current = loss.item(), (batch+1) * len(pos)
                if len(pos) < batch_size:
                    current = (batch) * batch_size + len(pos)
                print(f"mini-batch loss for training : {loss:>7f}  [{current:>5d}/{size_train:>5d}]")
        
        # Compute the global training loss as the mean of the mini-batch training losses
        # print(f"Training loss for epoch {Epoch+1} = {train_loss/size_train}")
      
        valid_loss = 0.0
        model.eval()
        # Test part : no gradient update
        with torch.no_grad():
            for batch, (anch, pos, neg) in enumerate(validloader):
                # Transfer Data to GPU if available
                anch, pos, neg = anch.to(device), pos.to(device), neg.to(device)

                anch_embedding = model(anch)
                pos_embedding = model(pos)
                neg_embedding = model(neg)

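                # Per-sample L2 normalization (dim=1), not a whole-batch norm
                anch_embedding = anch_embedding / anch_embedding.norm(dim=1, keepdim=True)
                pos_embedding = pos_embedding / pos_embedding.norm(dim=1, keepdim=True)
                neg_embedding = neg_embedding / neg_embedding.norm(dim=1, keepdim=True)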

                loss = triplet_loss(anch_embedding, pos_embedding, neg_embedding)

                # Calculate Loss
                valid_loss += loss.item()
                valid_loss_list.append(valid_loss)

                if batch == 0 or batch%1 == 0:
                    loss, current = loss.item(), (batch+1) * len(pos)
                    if len(pos) < batch_size:
                        current = (batch) * batch_size + len(pos)
                    print(f"mini-batch loss for validation : {loss:>7f}  [{current:>5d}/{size_test:>5d}]")
        
        # Compute the global training & validation loss as the mean of the mini-batch losses
        train_loss /= len(trainloader)
        valid_loss /= len(validloader)
        print(f"--End of epoch {epoch+1}/{epochs} \n Training Loss: {train_loss:>7f} \n Validation Loss: {valid_loss:>7f}" )
        print('\n')

    return train_loss_list, valid_loss_list
train_loss, valid_loss = model_loop(model = model,
                                    epochs = num_epochs,
                                    trainloader = train_dataloader,
                                    validloader = val_dataloader,
                                    batch_size = BATCH_SIZE, 
                                    anchor_img_ = anchor_t,
                                    optimizer = optimizer,
                                    triplet_loss = nn.TripletMarginLoss(margin=0.2),  # the keyword is margin, not alpha
                                    device = device)

Usually, when the triplet loss ends up at the margin value, it means the model output has gone to 0 for any input. Also, I see that you are sampling positive and negative examples at random, but in my experience the best strategy is semi-hard + hard triplets, skipping the easy ones.
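
To illustrate why a collapsed model lands exactly on the margin, here is a minimal sketch in plain Python (no PyTorch). The helper `triplet_loss` is a hypothetical stand-in that simply follows the formula quoted in the question. If every input maps to the same vector, both distances are 0 and only the margin is left:

```python
import math

def triplet_loss(a, p, n, margin=0.2):
    """max(d(a, p) - d(a, n) + margin, 0) with Euclidean distances."""
    return max(math.dist(a, p) - math.dist(a, n) + margin, 0.0)

# Collapsed model: every image is mapped to the same embedding
collapsed = [0.0, 0.0, 0.0]
loss_collapsed = triplet_loss(collapsed, collapsed, collapsed)

# Well-separated embeddings: the negative is much farther than the positive
loss_separated = triplet_loss([0.0, 0.0], [0.1, 0.0], [1.0, 0.0])

print(loss_collapsed, loss_separated)
```

The same value appears whenever d(A,P) and d(A,N) are equal, whatever the reason, so a loss pinned at the margin does not by itself say which cause is at work.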

Hi, thanks for your feedback.
Could you elaborate? I’m not sure I understand what you mean by

Also, I see that you are sampling positive and negative examples at random, but in my experience the best strategy is semi-hard + hard triplets, skipping the easy ones.

Thanks!

When you select triplets, the idea is that the model should learn something from the selection. If you select triplets with d(A, N) > d(A, P) + margin, the loss is 0, so the model has nothing to learn from them.
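
As a rough sketch of that distinction, assuming the usual definitions of easy, semi-hard and hard triplets (the function name `triplet_kind` and the example distances are made up for illustration):

```python
def triplet_kind(d_ap, d_an, margin=0.2):
    """Classify a triplet from its anchor-positive and anchor-negative distances."""
    if d_an > d_ap + margin:
        return "easy"        # loss is 0, nothing to learn
    if d_an > d_ap:
        return "semi-hard"   # negative is farther than the positive, but inside the margin
    return "hard"            # negative is at least as close as the positive

kinds = [triplet_kind(0.3, 0.9), triplet_kind(0.3, 0.4), triplet_kind(0.3, 0.2)]
print(kinds)
```

Only the semi-hard and hard cases give a non-zero loss, which is why they are the ones worth keeping.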

What I don’t understand is how d(A,N) could be greater than d(A,P) + margin if the model hasn’t learned anything yet. OK, the images A, N and P are different, but nothing tells me that f(A), f(N) and f(P) will behave like you said. Yet the loss is zero, so you’re right, but I don’t understand why.
A model that hasn’t learned anything should not be able to separate my images such that d(f(A), f(N)) > d(f(A), f(P)), no?

Thanks again :)

The idea behind triplet loss is that your model projects inputs to some position in a latent vector space. In the beginning, when the model hasn’t learned anything yet, these positions are pretty much random, so the outputs are scattered randomly across this space. Due to this randomness, some negatives may land near the anchor and some will end up far away. Triplet loss pushes negative examples away from the anchor-positive pair, penalizing a negative whose distance to the anchor is less than d(A,P) plus some margin.

OK, I read your link and your answers. I’m starting to understand why it cannot work.
So I need semi-hard and hard triplets…

But I’m stuck here. How can I create those triplets? Should I modify my images? Could I generate hard triplets from my small dataset? Since my models are not trained, how can I tell whether an image is a hard negative (I’m not going to look at each image and say “hmm, this guy looks like me but it’s not me”, am I? ^^)

Thanks !

You need to do an extra inference pass before training, gather the outputs, sort out the triplets, then run training with these triplets. This is the rather complicated part, since every time you update the weights your outputs change, and all the old outputs computed earlier for building triplets become outdated.
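
A minimal sketch of that selection step, in plain Python and under stated assumptions: the embeddings have already been gathered from an inference pass, each one carries an identity label, and for every anchor we keep the farthest positive and the closest negative (often called batch-hard mining, which is one strategy among several). The name `mine_hard_triplets` and the toy data are hypothetical:

```python
import math

def mine_hard_triplets(embeddings, labels):
    """For each anchor, return (anchor, hardest positive, hardest negative) indices."""
    triplets = []
    for a, (emb_a, lab_a) in enumerate(zip(embeddings, labels)):
        pos = [i for i, lab in enumerate(labels) if lab == lab_a and i != a]
        neg = [i for i, lab in enumerate(labels) if lab != lab_a]
        if not pos or not neg:
            continue  # no valid triplet for this anchor
        # Hardest positive: same identity, farthest from the anchor
        hardest_pos = max(pos, key=lambda i: math.dist(emb_a, embeddings[i]))
        # Hardest negative: different identity, closest to the anchor
        hardest_neg = min(neg, key=lambda i: math.dist(emb_a, embeddings[i]))
        triplets.append((a, hardest_pos, hardest_neg))
    return triplets

# Toy embeddings, as if gathered from an inference pass over the dataset
embeddings = [[0.0, 0.0], [0.1, 0.0], [0.5, 0.0], [0.6, 0.0], [2.0, 0.0]]
labels = ["me", "me", "me", "other", "other"]
triplets = mine_hard_triplets(embeddings, labels)
print(triplets)
```

Because the embeddings go stale after each weight update, this selection would have to be repeated regularly, for example once per epoch or inside each mini-batch.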

So I imagine you are referring to triplet mining?
What you’re saying is that I need to evaluate my model on the triplets at each training step and rebuild my triplets?

Do you have some documentation to help me with this task?

This post is the best I know of:


I am also facing the same problem: my triplet loss is stuck at the margin value and remains constant. I am working with the sketchANet network and the Shoe V2 dataset using PyTorch. This dataset has no labels.