Hi, I am trying to use a triplet loss to train a semantic segmentation model, using the following module:

```
import torch
import torch.nn as nn


class TripletLoss(nn.Module):
    def __init__(self, margin=None):
        super().__init__()
        self.margin = margin
        if self.margin is None:
            # No margin: soft-margin formulation, softplus(d(a, p) - d(a, n)).
            self.Loss = nn.SoftMarginLoss()
        else:
            # Hinge formulation: max(d(a, p) - d(a, n) + margin, 0).
            self.Loss = nn.TripletMarginLoss(margin=margin, p=2)

    def forward(self, anchor, pos, neg):
        # anchor, pos, neg: (N, D) embeddings, one triplet per row.
        if self.margin is None:
            ap_dist = torch.norm(anchor - pos, p=2, dim=1)  # (N,)
            an_dist = torch.norm(anchor - neg, p=2, dim=1)  # (N,)
            # Target +1 asks for an_dist - ap_dist to be large (positive).
            # ones_like keeps y on the same device and dtype as the distances.
            y = torch.ones_like(an_dist)
            loss = self.Loss(an_dist - ap_dist, y)
        else:
            loss = self.Loss(anchor, pos, neg)
        return loss
```
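For reference, here is what the module computes over N triplets $(a_i, p_i, n_i)$. The symbols are my own notation, not PyTorch's. When `margin` is `None`, `nn.SoftMarginLoss` with every target $y_i = 1$ reduces to a softplus of the distance difference:

$$
\mathcal{L}_{\text{soft}} = \frac{1}{N}\sum_{i=1}^{N} \log\left(1 + \exp\left(d(a_i, p_i) - d(a_i, n_i)\right)\right), \qquad d(x, y) = \lVert x - y \rVert_2
$$

With a numeric margin, `nn.TripletMarginLoss` (default `reduction="mean"`) instead computes the hinge $\frac{1}{N}\sum_{i} \max\left(d(a_i, p_i) - d(a_i, n_i) + \text{margin},\, 0\right)$, up to the small `eps` it adds inside the distance. Either way, the loss only needs three aligned $(N, D)$ tensors, so the remaining problem is how to build them from a dense embedding map.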

My question is how to sample triplets of pixel embeddings to compute this loss. Does anyone know of an implementation of a pixel-wise triplet loss for dense prediction tasks such as image segmentation?
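One way to do the sampling is shown in the sketch below. This is my own approach, not an existing library API. Treat every labelled pixel as a sample and draw a random pool of pixels. Within that pool, pick for each anchor a positive with the same class and a negative with a different class, either uniformly at random or by "batch-hard" mining (farthest positive, closest negative). The function `sample_pixel_triplets` and its parameters `pool_size`, `ignore_index=255` and `hard` are hypothetical names. The sketch assumes embeddings of shape (B, D, H, W) and an integer label map of shape (B, H, W) at the same resolution.

```python
import torch
import torch.nn as nn


def sample_pixel_triplets(embeddings, labels, pool_size=2048,
                          ignore_index=255, hard=False):
    """Hypothetical helper: sample (anchor, positive, negative) pixel embeddings.

    embeddings: (B, D, H, W) float tensor from the network's embedding head.
    labels:     (B, H, W) integer class map at the same resolution.
    Returns three aligned (N, D) tensors with N <= pool_size.
    """
    D = embeddings.shape[1]
    # Treat every pixel of every image as one sample: (B*H*W, D) and (B*H*W,).
    emb = embeddings.permute(0, 2, 3, 1).reshape(-1, D)
    lab = labels.reshape(-1)
    keep = lab != ignore_index  # drop unlabelled pixels (assumed ignore value)
    emb, lab = emb[keep], lab[keep]

    # Random pool of pixels. All pairwise comparisons happen inside the pool,
    # so memory is O(pool_size^2) instead of O((B*H*W)^2).
    perm = torch.randperm(lab.numel(), device=lab.device)[:pool_size]
    emb, lab = emb[perm], lab[perm]
    M = lab.numel()
    if M == 0:
        empty = emb.new_zeros((0, D))
        return empty, empty, empty

    same = lab.unsqueeze(0) == lab.unsqueeze(1)           # (M, M) same-class mask
    eye = torch.eye(M, dtype=torch.bool, device=lab.device)
    pos_mask = same & ~eye                                # same class, not itself
    neg_mask = ~same                                      # different class

    if hard:
        # Batch-hard mining: farthest positive and closest negative per anchor.
        # Only the indices are chosen without gradient; the embeddings gathered
        # below still carry gradients into the network.
        with torch.no_grad():
            dist = torch.cdist(emb, emb)                  # (M, M) Euclidean distances
            pos_idx = dist.masked_fill(~pos_mask, float("-inf")).argmax(dim=1)
            neg_idx = dist.masked_fill(~neg_mask, float("inf")).argmin(dim=1)
    else:
        # Uniform random choice: random scores in [0, 1), invalid entries set to -1.
        scores = torch.rand(M, M, device=lab.device)
        pos_idx = scores.masked_fill(~pos_mask, -1.0).argmax(dim=1)
        neg_idx = scores.masked_fill(~neg_mask, -1.0).argmax(dim=1)

    # Keep only anchors that have at least one positive and one negative in the pool.
    valid = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    a = valid.nonzero(as_tuple=True)[0]
    return emb[a], emb[pos_idx[a]], emb[neg_idx[a]]


if __name__ == "__main__":
    torch.manual_seed(0)
    feats = torch.randn(2, 16, 32, 32, requires_grad=True)  # (B, D, H, W)
    target = torch.randint(0, 4, (2, 32, 32))               # (B, H, W)
    a, p, n = sample_pixel_triplets(feats, target, pool_size=512, hard=True)
    # margin=0.3 is an arbitrary example value; these (N, D) tensors can also
    # be passed to the TripletLoss module from the question.
    loss = nn.TripletMarginLoss(margin=0.3, p=2)(a, p, n)
    loss.backward()
    print(a.shape, loss.item())
```

If the embedding head outputs a lower resolution than the input, the label map can be resized to match with nearest-neighbour interpolation, e.g. `F.interpolate(labels[:, None].float(), size=(h, w), mode="nearest").squeeze(1).long()`, so that no mixed labels are created. The random pool is a trade-off: a larger `pool_size` gives more (and harder) triplets per step at quadratic memory cost. Random sampling is usually more stable early in training, while batch-hard mining tends to give a stronger signal once the embeddings are partly separated.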