Applying Centroid Triplet Loss (CTL) to OSNet

Hello, I am trying to apply Centroid Triplet Loss (CTL) to OSNet.
I am using torchreid, cloned from https://github.com/KaiyangZhou/deep-person-reid
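
For context, this is roughly how I wire the loss into training (a minimal sketch, not my full loop; num_classes, the input resolution, and the P x M batch composition are illustrative, and it assumes OSNet returns (logits, features) in training mode when built with loss='triplet'):

import torch
import torchreid

# OSNet built through torchreid; with loss='triplet' the model returns
# (logits, features) during training, and the features go to the loss.
model = torchreid.models.build_model(
    name="osnet_x1_0", num_classes=751, loss="triplet", pretrained=True
)
model.train()

loss_fn = CentroidTripletLoss(margin=0.05)  # defined below

# A P x M batch: P = 8 identities with M = 4 images each.
images = torch.randn(32, 3, 256, 128)
pids = torch.arange(8).repeat_interleave(4)

logits, features = model(images)
loss = loss_fn(features, pids)
loss.backward()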

Here is the code for the loss function, taken from
https://github.com/KevinMusgrave/pytorch-metric-learning/issues/374:

from collections import defaultdict

import torch

from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.losses import BaseMetricLossFunction, TripletMarginLoss


def concat_indices_tuple(x):
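    # x is a list of (anchor, positive, negative) index-tensor triples;
    # merge them component-wise into a single triple.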
    return [torch.cat(y) for y in zip(*x)]


class CentroidTripletLoss(BaseMetricLossFunction):
    def __init__(
        self,
        margin=0.05,
        swap=False,
        smooth_loss=False,
        triplets_per_anchor="all",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.triplet_loss = TripletMarginLoss(
            margin=margin,
            swap=swap,
            smooth_loss=smooth_loss,
            triplets_per_anchor=triplets_per_anchor,
            **kwargs
        )

    def compute_loss(
        self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None
    ):
        c_f.indices_tuple_not_supported(indices_tuple)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
        """
        "During training stage each mini-batch contains P distinct item
        classes with M samples per class, resulting in batch size of P x M."
        """
        masks, class_masks, labels_list, query_indices = self.create_masks_train(labels)

        P = len(labels_list)
        M = max([len(instances) for instances in labels_list])
        DIM = embeddings.size(-1)

        """
        "...each sample from Sk is used as a query qk and the rest 
        M - 1 samples are used to build a prototype centroid"
        i.e. for each class k of M items, we make M pairs of (query, centroid),
        making a total of P*M total pairs.
        masks = (M*P x len(embeddings)) matrix
        labels_list[i] = indicies of embeddings belonging to ith class
        centroids_emd.shape == (M*P, DIM)
        i.e.    centroids_emb[0] == centroid vector for 0th class, where the first embedding is the query vector
                centroids_emb[1] == centroid vector for 0th class, where the second embedding is the query vector
                centroids_emb[M+1] == centroid vector for 1th class, where the first embedding is the query vector
        """

        masks_float = masks.type(embeddings.type()).to(embeddings.device)
        class_masks_float = class_masks.type(embeddings.type()).to(embeddings.device)
        inst_counts = masks_float.sum(-1)
        class_inst_counts = class_masks_float.sum(-1)

        valid_mask = inst_counts > 0
        padded = masks_float.unsqueeze(-1) * embeddings.unsqueeze(0)
        class_padded = class_masks_float.unsqueeze(-1) * embeddings.unsqueeze(0)

        positive_centroids_emb = padded.sum(-2) / inst_counts.masked_fill(
            inst_counts == 0, 1
        ).unsqueeze(-1)

        negative_centroids_emb = class_padded.sum(-2) / class_inst_counts.masked_fill(
            class_inst_counts == 0, 1
        ).unsqueeze(-1)

        query_indices = torch.tensor(query_indices).to(embeddings.device)
        query_embeddings = embeddings.index_select(0, query_indices)
        query_labels = labels.index_select(0, query_indices)
        assert positive_centroids_emb.size() == (M * P, DIM)
        assert negative_centroids_emb.size() == (P, DIM)
        assert query_embeddings.size() == (M * P, DIM)

        query_indices = query_indices.view((P, M)).transpose(0, 1)
        query_embeddings = query_embeddings.view((P, M, -1)).transpose(0, 1)
        query_labels = query_labels.view((P, M)).transpose(0, 1)
        positive_centroids_emb = positive_centroids_emb.view((P, M, -1)).transpose(0, 1)
        valid_mask = valid_mask.view((P, M)).transpose(0, 1)

        labels_collect = []
        embeddings_collect = []
        tuple_indices_collect = []
        starting_idx = 0
        for inst_idx in range(M):
            one_mask = valid_mask[inst_idx]
            if torch.sum(one_mask) > 1:
                anchors = query_embeddings[inst_idx][one_mask]
                pos_centroids = positive_centroids_emb[inst_idx][one_mask]
                one_labels = query_labels[inst_idx][one_mask]

                embeddings_concat = torch.cat(
                    (anchors, pos_centroids, negative_centroids_emb)
                )
                labels_concat = torch.cat(
                    (one_labels, one_labels, query_labels[inst_idx])
                )
                indices_tuple = get_all_triplets_indices(labels_concat)

                """
                Right now indices tuple considers all embeddings in
                embeddings_concat as anchors, pos_example, neg_examples.
                1. make only query vectors be anchor vectors
                2. make pos_centroids be only used as a positive example
                3. negative as so
                """
                # make only query vectors be anchor vectors
                indices_tuple = [x[: len(x) // 3] + starting_idx for x in indices_tuple]

                # make only pos_centroids be positive examples
                indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
                indices_tuple = [x.chunk(2, dim=1)[0] for x in indices_tuple]

                # make only neg_centroids be negative examples
                indices_tuple = [
                    x.chunk(len(one_labels), dim=1)[-1].flatten() for x in indices_tuple
                ]

                tuple_indices_collect.append(indices_tuple)
                embeddings_collect.append(embeddings_concat)
                labels_collect.append(labels_concat)
                starting_idx += len(labels_concat)

        indices_tuple = concat_indices_tuple(tuple_indices_collect)

        if len(indices_tuple) == 0:
            return self.zero_losses()

        final_embeddings = torch.cat(embeddings_collect)
        final_labels = torch.cat(labels_collect)

        loss = self.triplet_loss.compute_loss(
            final_embeddings, final_labels, indices_tuple, ref_emb=None, ref_labels=None
        )
        return loss
    
    def create_masks_train(self, class_labels):
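        """Build (M*P x N) query masks and (P x N) class masks, N = len(class_labels).

        e.g. class_labels = [0, 0, 1] gives M = 2, P = 2 and
        query_indices = [0, 1, 2, 2]; masks[0] selects {1}, masks[1]
        selects {0}, and masks[2], masks[3] stay all-zero because class 1
        has fewer than M samples, so it never forms a valid query.
        """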
        labels_dict = defaultdict(list)
        class_labels = class_labels.detach().cpu().numpy()
        for idx, pid in enumerate(class_labels):
            labels_dict[pid].append(idx)

        unique_classes = list(labels_dict.keys())
        labels_list = list(labels_dict.values())

        M = max(len(instances) for instances in labels_list)
        P = len(unique_classes)

        query_indices = []
        class_masks = torch.zeros((P, len(class_labels)), dtype=torch.bool)
        masks = torch.zeros((M * P, len(class_labels)), dtype=torch.bool)
        for class_idx, class_insts in enumerate(labels_list):
            class_masks[class_idx, class_insts] = 1
            for instance_idx in range(M):
                matrix_idx = class_idx * M + instance_idx
                if instance_idx < len(class_insts):
                    query_indices.append(class_insts[instance_idx])
                    ones = class_insts[:instance_idx] + class_insts[instance_idx + 1 :]
                    masks[matrix_idx, ones] = 1
                else:
                    query_indices.append(class_insts[0])
        return masks, class_masks, labels_list, query_indices


def get_all_triplets_indices(labels, ref_labels=None):
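    # Returns three aligned index tensors (anchors, positives, negatives)
    # enumerating every valid triplet implied by the labels.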
    matches, diffs = get_matches_and_diffs(labels, ref_labels)
    triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
    return torch.where(triplets)


def get_matches_and_diffs(labels, ref_labels=None):
    if ref_labels is None:
        ref_labels = labels
    labels1 = labels.unsqueeze(1)
    labels2 = ref_labels.unsqueeze(0)
    matches = (labels1 == labels2).byte()
    diffs = matches ^ 1
    if ref_labels is labels:
        matches.fill_diagonal_(0)
    return matches, diffs
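
To rule out the training loop, I can also exercise the loss in isolation on synthetic data (a sketch; only the P x M batch convention matters, the dimensions are illustrative):

if __name__ == "__main__":
    loss_fn = CentroidTripletLoss(margin=0.05)

    # P = 3 identities with M = 4 samples each, 512-d embeddings.
    embeddings = torch.randn(12, 512)
    labels = torch.tensor([0] * 4 + [1] * 4 + [2] * 4)

    print(loss_fn(embeddings, labels))  # scalar loss tensor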

When I train the model with this loss function, the loss is around 0.007 and doesn’t change.
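
In case it helps with diagnosis, here is a quick check I can run on the spread of the embeddings (a sketch; features is the OSNet output from the snippet near the top):

with torch.no_grad():
    dists = torch.cdist(features, features)  # pairwise L2 distances
    print("mean pairwise distance:", dists.mean().item())
    print("min / max:", dists.min().item(), dists.max().item())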

I would really like to fix this issue.

Any help is appreciated. Thank you!

Here is the link to the original CTL model:
https://github.com/mikwieczorek/centroids-reid