Batch Hard Triplet Loss: backpropagation fails with RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output

Hi guys!

I have been trying to implement this paper, which proposes triplet loss with batch hard mining, for a facial recognition task.

Based on my understanding of the paper, I have written the loss function as follows:

# https://arxiv.org/pdf/1703.07737.pdf
import torch
import torch.nn as nn

class batchHardTripletLoss(nn.Module):
    def __init__(self, margin = 0.2, squared = False, agg = "mean"):
        """
        Initialize the loss function with a margin parameter, whether or not to consider
        squared Euclidean distance and how to aggregate the loss in a batch
        """
        super(batchHardTripletLoss, self).__init__()
        self.margin = margin
        self.squared = squared
        self.agg = agg
        self.eps = 1e-8
    
    def get_pairwise_distances(self, feat_vecs):
        """
        Computing distance for every pair using 
        (a - b) ^ 2 = a^2 - 2ab + b^2 
        """
        ab = feat_vecs.mm(feat_vecs.t())
        a_squared = ab.diag().unsqueeze(1)
        b_squared = ab.diag().unsqueeze(0)
        distances = a_squared - 2 * ab + b_squared
        # Floating point error can make some entries slightly negative, so clamp them
        distances = distances.clamp(min = self.eps)
        
        if not self.squared:
            distances = torch.sqrt(distances + self.eps)

        return distances
            
        
    def get_mask(self, labels, type_ = "positive"):
        """
        Get a binary matrix corresponding to valid duplet pairs for 
        (anchor, positive) & (anchor, negative) pairs
        """
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        PK = labels.shape[1]
        mask = torch.zeros(PK, PK).to(DEVICE)
        
        for idx, item in enumerate(labels[0]):
            for inner_idx, inner_item in enumerate(labels[0]):
                
                if type_ == "positive":
                    
                    # Labels should match and the image index shouldn't be the same
                    if (item == inner_item) and (idx != inner_idx):
                        mask[idx, inner_idx] = 1
                elif type_ == "negative":
                    
                    # Labels must be different and image index shouldn't be the same (redundant but still...)
                    if (item != inner_item) and (idx != inner_idx):
                        mask[idx, inner_idx] = 1
                    
        return mask
    
    def forward(self, feat_vecs, labels):
        """
        Define the loss function implementation here
        """
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # Get the pairwise distances of all images from one another
        distances = self.get_pairwise_distances(feat_vecs)
        
        # Get the toughest positive pair by first filtering out the (anchor, positive)
        # pairs using the get_mask routine and then find the max across rows
        positive_mask = self.get_mask(labels, type_ = "positive").to(DEVICE)
        toughest_positive_distance = (distances * positive_mask).max(dim = 1)[0]
        
        negative_mask = self.get_mask(labels, type_ = "negative").to(DEVICE)
        
        # Add each row's maximum distance to all the non-valid pairs so that
        # they can never be selected; the row-wise minimum of the result is
        # then the distance of the toughest (anchor, negative) pair
        max_negative_dist = distances.max(dim=1,keepdim=True)[0]
        distances = distances + max_negative_dist * (1 - negative_mask).float()
        toughest_negative_distance = distances.min(dim = 1)[0]
        
        # Find the triplet loss by using the two distances obtained above
        triplet_loss = (toughest_positive_distance - toughest_negative_distance + self.margin).clamp(min = self.eps)
        
        # Aggregate the loss to mean/sum based on the initialization of the loss function
        if self.agg == "mean":
            triplet_loss = triplet_loss.mean()
        elif self.agg == "sum":
            triplet_loss = triplet_loss.sum()
            
        return triplet_loss
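
For context, here is roughly how I call it (a minimal sketch with made-up shapes: feat_vecs is the (P*K, D) batch of embeddings and labels has shape (1, P*K), matching how get_mask indexes labels[0]):

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

P, K, D = 4, 4, 128                 # P identities, K images each, D-dim embeddings
feat_vecs = torch.randn(P * K, D, device = DEVICE, requires_grad = True)
labels = torch.arange(P).repeat_interleave(K).unsqueeze(0).to(DEVICE)  # shape (1, P*K)

criterion = batchHardTripletLoss(margin = 0.2)
loss = criterion(feat_vecs, labels)
loss.backward()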

However, when I train with this loss, the first few epochs run smoothly, but then training starts throwing RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.
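
As far as I understand, this particular error message only appears because I run training with autograd anomaly detection enabled; without it, the NaNs would just propagate silently into the gradients. A tiny standalone repro of the same error:

import torch

# With anomaly detection on, autograd raises an error (pointing at the
# offending forward op) as soon as a backward function produces NaN
torch.autograd.set_detect_anomaly(True)

x = torch.tensor(-1e-9, requires_grad = True)  # slightly negative, like a distance hit by floating point error
y = torch.sqrt(x)                              # forward pass already produces nan
y.backward()                                   # RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.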

I tried to follow this thread and added a small epsilon before taking the square root to make sure the argument is non-zero, but it still didn't work…
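
My suspicion (possibly wrong) is that the epsilon only protects against exact zeros: the derivative of sqrt(x) is 1 / (2 * sqrt(x)), so the gradient still explodes for tiny positive inputs, and floating point error in the a^2 - 2ab + b^2 expansion can leave slightly negative entries that sqrt turns into NaN outright. A quick check of both failure modes:

import torch

x = torch.tensor([1e-12, -1e-9], requires_grad = True)
y = torch.sqrt(x)      # second entry is already nan in the forward pass
y.sum().backward()
print(x.grad)          # tensor([5.0000e+05, nan]) -- exploding gradient and NaN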

Can someone please guide me on how to tackle this issue? I would be much obliged.

Thanks :smiley:!

torch.clamp has a different gradient from what you are expecting here. Check this post: Pytorch Autograd gives different gradients when using .clamp instead of torch.relu. Maybe try relu instead of clamp.
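
For a concrete picture of the difference, here is a minimal check: both zero out the gradient below the threshold, but exactly at the boundary clamp backpropagates a gradient of 1 while relu backpropagates 0.

import torch

x = torch.tensor(0.0, requires_grad = True)
x.clamp(min = 0.0).backward()
print(x.grad)    # tensor(1.) -- clamp passes the upstream gradient through at the boundary

x = torch.tensor(0.0, requires_grad = True)
torch.relu(x).backward()
print(x.grad)    # tensor(0.) -- relu blocks it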


Hi @karanjeswani,

The post you pointed to helped!

After substituting nn.ReLU for torch.clamp, training ran fine with no more NaN errors!

# Previously in the loss function:
distances = distances.clamp(min = self.eps)
triplet_loss = (toughest_positive_distance - toughest_negative_distance + self.margin).clamp(min = self.eps)

# Currently in the loss function:
distances = nn.ReLU()(distances)
triplet_loss = nn.ReLU()(toughest_positive_distance - toughest_negative_distance + self.margin)
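
As a quick sanity check (duplicate feature vectors force a pairwise distance of exactly zero, which is the case that used to blow up; shapes here are made up):

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

feat_vecs = torch.randn(8, 128, device = DEVICE)
feat_vecs[1] = feat_vecs[0]          # identical pair -> zero distance
feat_vecs.requires_grad_(True)
labels = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3]], device = DEVICE)

loss = batchHardTripletLoss()(feat_vecs, labels)
loss.backward()
print(torch.isnan(feat_vecs.grad).any())   # tensor(False)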
        

Thanks a ton!
