I have been trying to implement this paper which mentions triplet loss with batch hard mining for facial recognition.
Based on my understanding of the paper, I have written the loss function as follows
# https://arxiv.org/pdf/1703.07737.pdf import torch import torch.nn as nn class batchHardTripletLoss(nn.Module): def __init__(self, margin = 0.2, squared = False, agg = "mean"): """ Initialize the loss function with a margin parameter, whether or not to consider squared Euclidean distance and how to aggregate the loss in a batch """ super(batchHardTripletLoss, self).__init__() self.margin = margin self.squared = squared self.agg = agg self.eps = 1e-8 def get_pairwise_distances(self, feat_vecs): """ Computing distance for every pair using (a - b) ^ 2 = a^2 - 2ab + b^2 """ ab = feat_vecs.mm(feat_vecs.t()) a_squared = ab.diag().unsqueeze(1) b_squared = ab.diag().unsqueeze(0) distances = a_squared - 2 * ab + b_squared distances.clamp(min = self.eps) if not self.squared: distances = torch.sqrt(distances + self.eps) return distances def get_mask(self, labels, type_ = "positive"): """ Get a binary matrix corresponding to valid duplet pairs for (anchor, positive) & (anchor, negative) pairs """ DEVICE = "cuda" if torch.cuda.is_available() else "cpu" PK = labels.shape mask = torch.zeros(PK, PK).to(DEVICE) for idx, item in enumerate(labels): for inner_idx, inner_item in enumerate(labels): if type_ == "positive": # Labels should match and the image index shouldn't be the same if (item == inner_item) and (idx != inner_idx): mask[idx, inner_idx] = 1 elif type_ == "negative": # Labels must be different and image index shouldn't be the same (redundant but still...) if (item != inner_item) and (idx != inner_idx): mask[idx, inner_idx] = 1 return mask def forward(self, feat_vecs, labels): """ Define the loss function implementation here """ DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Get the pairwise distances of all images from one another distances = self.get_pairwise_distances(feat_vecs) # Get the toughest positive pair by first filtering out the (anchor, positive) # pairs using the get_mask routine and then find the max across rows positive_mask = self.get_mask(labels, type_ = "positive").to(DEVICE) toughest_positive_distance = (distances * positive_mask).max(dim = 1) negative_mask = self.get_mask(labels, type_ = "negative").to(DEVICE) # Add the maxiumum negative distance to all the non-valid pairs # on a rowwise basis and then out of them whichever is the minimum # will be our pair distance corresponding to toughest (anchor, negative) pair max_negative_dist = distances.max(dim=1,keepdim=True) distances = distances + max_negative_dist * (1 - negative_mask).float() toughest_negative_distance = distances.min(dim = 1) # Find the triplet loss by using the two distances obtained above triplet_loss = (toughest_positive_distance - toughest_negative_distance + self.margin).clamp(min = self.eps) # Aggregate the loss to mean/sum based on the initialization of the loss function if self.agg == "mean": triplet_loss = triplet_loss.mean() elif self.agg == "sum": triplet_loss = triplet_loss.sum() return triplet_loss
However when I go to training, the training runs smoothly for the first few epochs but later, it starts throwing
RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output. this error. Here’s a screenshot of the same.
I tried to follow this thread and added a small epsilon quantity before taking square root to make sure it’s non-zero but it still didn’t work…
Can someone please guide me on how I can tackle this issue? I would be highly obliged.