Triplet loss for semantic segmentation

Hi, I am trying to use a triplet loss to train a semantic segmentation model.

import torch
import torch.nn as nn

class TripletLoss(nn.Module):
    def __init__(self, margin=None):
        super(TripletLoss, self).__init__()
        self.margin = margin
        if self.margin is None:
            # No fixed margin: use the soft-margin formulation on distance differences
            self.Loss = nn.SoftMarginLoss()
        else:
            self.Loss = nn.TripletMarginLoss(margin=margin, p=2)

    def forward(self, anchor, pos, neg):
        if self.margin is None:
            num_samples = anchor.shape[0]
            y = torch.ones(num_samples, device=anchor.device)
            # Euclidean distances between anchor-positive and anchor-negative pairs
            ap_dist = torch.norm(anchor - pos, 2, dim=1)
            an_dist = torch.norm(anchor - neg, 2, dim=1)
            loss = self.Loss(an_dist - ap_dist, y)
        else:
            loss = self.Loss(anchor, pos, neg)

        return loss
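
For reference, the loss above expects three batches of embedding vectors of shape [N, D]; the shapes here are just illustrative:

anchor = torch.randn(32, 128)   # 32 anchor embeddings, 128-dimensional
pos = torch.randn(32, 128)      # positives (same class as the anchors)
neg = torch.randn(32, 128)      # negatives (different class)
loss = TripletLoss(margin=0.3)(anchor, pos, neg)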

My question is how to sample triplets of pixel embeddings for computing this loss. Does anyone know of an existing implementation of a pixel-wise triplet loss for dense tasks such as image segmentation?
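
To make the question more concrete, this is the kind of per-image random sampling I have in mind: just a rough sketch, assuming pixel embeddings of shape [B, C, H, W] and ground-truth labels of shape [B, H, W], with sample_triplets being my own placeholder name.

def sample_triplets(embeddings, labels, num_triplets=256):
    # embeddings: [B, C, H, W] pixel embeddings, labels: [B, H, W] class ids
    B, C, H, W = embeddings.shape
    flat_emb = embeddings.permute(0, 2, 3, 1).reshape(B, H * W, C)   # [B, H*W, C]
    flat_lbl = labels.reshape(B, H * W)                              # [B, H*W]

    anchors, positives, negatives = [], [], []
    for b in range(B):
        for _ in range(num_triplets):
            # pick a random anchor pixel and look up its class
            a = torch.randint(0, H * W, (1,)).item()
            a_cls = flat_lbl[b, a]
            # positives: other pixels of the same class; negatives: any other class
            pos_idx = (flat_lbl[b] == a_cls).nonzero(as_tuple=True)[0]
            neg_idx = (flat_lbl[b] != a_cls).nonzero(as_tuple=True)[0]
            if len(pos_idx) < 2 or len(neg_idx) == 0:
                continue
            p = pos_idx[torch.randint(0, len(pos_idx), (1,))].item()
            n = neg_idx[torch.randint(0, len(neg_idx), (1,))].item()
            anchors.append(flat_emb[b, a])
            positives.append(flat_emb[b, p])
            negatives.append(flat_emb[b, n])

    return torch.stack(anchors), torch.stack(positives), torch.stack(negatives)

The stacked [N, C] tensors would then go straight into the TripletLoss above, but I am not sure whether this kind of purely random sampling is the right approach for dense prediction.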