Training strategy for triplet loss


I’m trying to train a model with triplet loss, and I wonder if I am on the right track in preparing triplets and batches. I’m using the PyTorch implementation, TripletMarginLoss. This is a long post, sorry about that.

My data consists of short, variable-length documents. Each document is labeled with a class (almost 50K documents and 1000 classes). I first encode the documents so that each has a fixed-length vector representation; those vectors are the inputs to the loss function.
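Since the encoder is out of scope here, the following is only an illustration of the idea, not the author's encoder: a minimal sketch, assuming documents are already tokenized into integer ids, of how an averaging encoder can turn variable-length documents into fixed-length vectors with nn.EmbeddingBag(mode="mean"). VOCAB_SIZE, EMBED_DIM, AveragingEncoder and pack are hypothetical names.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
VOCAB_SIZE = 1000
EMBED_DIM = 16


class AveragingEncoder(nn.Module):
    """Mean of token embeddings followed by one linear layer.

    A deep averaging network would stack more layers; one is kept for brevity.
    """

    def __init__(self, vocab_size: int, embed_dim: int, out_dim: int):
        super().__init__()
        # mode="mean" averages the embeddings inside each bag (one bag = one document)
        self.embed = nn.EmbeddingBag(vocab_size, embed_dim, mode="mean")
        self.proj = nn.Linear(embed_dim, out_dim)

    def forward(self, token_ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(token_ids, offsets))


def pack(docs):
    """Flatten a list of token-id lists into (token_ids, offsets) for EmbeddingBag."""
    token_ids = torch.tensor([t for doc in docs for t in doc], dtype=torch.long)
    lengths = [len(doc) for doc in docs]
    # offsets[i] is the start position of document i in the flat token list
    offsets = torch.tensor([0] + lengths[:-1], dtype=torch.long).cumsum(dim=0)
    return token_ids, offsets


encoder = AveragingEncoder(VOCAB_SIZE, EMBED_DIM, out_dim=8)
docs = [[1, 5, 9], [2, 7], [3, 3, 3, 3, 4]]  # three documents of different lengths
vectors = encoder(*pack(docs))
print(vectors.shape)  # torch.Size([3, 8])
```

Every document ends up as one 8-dimensional vector regardless of its length, which is the kind of input TripletMarginLoss expects.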

Of course, the main motivation is to increase the distance between documents in different classes. It may sound like an easy task (and perhaps a good fit for multi-class classification). However, documents of some classes tend to be close to each other, and classification was not good enough at distinguishing them. In other words, for each class there is a list of classes whose documents lie close to its own documents (and they shouldn’t). My goal is to pull documents closer to those in their own class and push them relatively far from those in similar classes. I generate triplets based on this idea (order is not important, just an example):

doc_1 (class a, anchor), doc_2 (class a, positive), doc_3 (class b, negative),
doc_2 (class a, anchor), doc_1 (class a, positive), doc_4 (class c, negative)
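The pairing scheme above might be generated along these lines. This is a minimal sketch in plain Python, assuming a hypothetical confusable dict that maps each class to its list of similar classes; make_triplets, count_candidate_triplets and per_anchor are illustrative names, and the triplets hold document indices rather than encoded vectors. The second function counts how many distinct triplets the same scheme allows, which puts the triplet count into perspective.

```python
import random
from collections import defaultdict


def make_triplets(labels, confusable, per_anchor=1, seed=0):
    """Sample (anchor, positive, negative) document-index triplets.

    labels: labels[i] is the class of document i.
    confusable: hypothetical dict mapping a class to the classes whose
        documents lie too close to it; negatives are drawn only from those.
    per_anchor: number of triplets sampled for each anchor document.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    triplets = []
    for anchor, label in enumerate(labels):
        positives = [i for i in by_class[label] if i != anchor]
        neg_classes = [c for c in confusable.get(label, []) if by_class.get(c)]
        if not positives or not neg_classes:
            continue  # singleton class or no known confusable class: skip anchor
        for _ in range(per_anchor):
            positive = rng.choice(positives)
            negative = rng.choice(by_class[rng.choice(neg_classes)])
            triplets.append((anchor, positive, negative))
    return triplets


def count_candidate_triplets(class_sizes, confusable):
    """Distinct triplets the scheme allows: for each class with n documents,
    n * (n - 1) ordered (anchor, positive) pairs times its confusable documents."""
    total = 0
    for cls, n in class_sizes.items():
        negatives = sum(class_sizes.get(c, 0) for c in confusable.get(cls, []))
        total += n * (n - 1) * negatives
    return total


labels = ["a", "a", "b", "c", "c"]
confusable = {"a": ["b", "c"], "c": ["a"]}
triplets = make_triplets(labels, confusable)
print(triplets)  # 4 triplets; anchor 2 (class "b", a single document) is skipped

# Made-up class sizes, only to show how fast the candidate space grows.
print(count_candidate_triplets({"a": 30, "b": 5, "c": 40},
                               {"a": ["b", "c"], "b": ["a"], "c": ["a"]}))  # 86550
```

Even these three toy classes allow 86550 distinct triplets, so with 35K training documents a fixed list of 40000 is likely a small sample of what the scheme can produce; raising per_anchor, or resampling triplets every epoch, are simple ways to draw more of them.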

I tested this idea with 40000 triplets, batch_size=4, the Adam optimizer, gradient clipping (the loss exploded otherwise) and margin=1.0. My encoder is a simple deep averaging network (the encoder is out of scope for this post). With batch_size=4, each batch contains 12 documents (4 anchors, 4 positives, 4 negatives). After 30 epochs, the training loss (averaged over all batches) was about 0.3 and the test loss was around 0.35. After 50 epochs, the training loss was 0.04 and the test loss was around 0.33.

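for anchor_docs, positive_docs, negative_docs in batches:
    self.optimizer.zero_grad()
    anchor_encoded = self.model(anchor_docs)
    positive_encoded = self.model(positive_docs)
    negative_encoded = self.model(negative_docs)
    loss = self.criterion(anchor_encoded, positive_encoded, negative_encoded)
    loss.backward()
    # clip gradients before the Adam step; max_norm=1.0 is a placeholder value
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
    self.optimizer.step()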
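For reference, here is a self-contained version of one training step. It is a sketch assuming a toy MLP in place of the deep averaging network and random tensors in place of encoded documents; layer sizes, the learning rate and max_norm are illustrative, not values from this post. It also recomputes the loss by hand, which shows what TripletMarginLoss computes with its defaults (Euclidean distance, mean over the batch).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-in for the document encoder; inputs are already fixed-length vectors.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
criterion = nn.TripletMarginLoss(margin=1.0)  # p=2 (Euclidean) by default
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_step(anchor_docs, positive_docs, negative_docs, max_norm=1.0):
    model.train()
    optimizer.zero_grad()
    # One shared encoder embeds all three parts of each triplet.
    a = model(anchor_docs)
    p = model(positive_docs)
    n = model(negative_docs)
    loss = criterion(a, p, n)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    # Return the embeddings computed before the update, for the check below.
    return loss.item(), (a.detach(), p.detach(), n.detach())


batch_size = 4  # 4 triplets -> 12 documents per batch
anchor = torch.randn(batch_size, 32)
positive = torch.randn(batch_size, 32)
negative = torch.randn(batch_size, 32)
loss_value, (a, p, n) = train_step(anchor, positive, negative)

# Same value by hand: mean over the batch of max(d(a, p) - d(a, n) + margin, 0).
manual = torch.clamp(
    F.pairwise_distance(a, p) - F.pairwise_distance(a, n) + 1.0, min=0.0
).mean()
print(loss_value, manual.item())
```

The hand computation makes it visible that a triplet stops contributing once the negative is at least margin farther from the anchor than the positive, so a training loss near zero mostly means the training triplets already satisfy the margin.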

So this first try tells me that it is not that bad, but of course the number of triplets is low compared to the dataset size (50K documents, 35K in training). Nonetheless, I wonder:

  1. Should each batch contain only one class as the (anchor, positive) class? I don’t do that, since some classes are small in terms of number of documents. I just take the triplets I constructed beforehand and divide them into batches. I already have a target (hard and semi-hard triplets), so I simply created a list of them. Or should I build large batches where every anchor is from a different class? That would mean batches of 1000 triplets. I ask because the guides I’ve read so far confused me about how to generate triplets and batches.

  2. Would it help to add random triplets (where the negative class has no relation to the anchor class)?

  3. How many triplets should I generate relative to the dataset size?

  4. I basically assume that documents in a class are similar to each other, but in reality there can be differences. Would using TF-IDF or another similarity metric when generating triplets help?
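Regarding questions 1 and 2, a commonly used alternative to a precomputed triplet list is online "batch-hard" mining: each batch holds P classes with K documents each, and triplets are formed inside the batch, using for every anchor its farthest same-class document and its closest other-class document. This is a swapped-in technique, not what the post does, and it replaces nn.TripletMarginLoss with a custom loss. A minimal sketch, assuming a class label per document and random tensors in place of encoded documents; pk_batches, batch_hard_triplet_loss, p_classes and k_samples are hypothetical names.

```python
import random
from collections import defaultdict

import torch


def pk_batches(labels, p_classes, k_samples, seed=0):
    """Yield lists of document indices: P classes x K documents per batch.

    Classes with fewer than k_samples documents are skipped, and leftover
    classes that do not fill a whole batch are dropped for this pass.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_samples]
    rng.shuffle(eligible)
    for start in range(0, len(eligible) - p_classes + 1, p_classes):
        batch = []
        for c in eligible[start:start + p_classes]:
            batch.extend(rng.sample(by_class[c], k_samples))
        yield batch


def batch_hard_triplet_loss(embeddings, labels, margin=1.0):
    """Batch-hard triplet loss: hardest positive and hardest negative per anchor."""
    diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)  # (B, B, D)
    # Small constant keeps the sqrt gradient finite on the zero diagonal.
    dist = (diff.pow(2).sum(dim=-1) + 1e-12).sqrt()
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), device=labels.device).bool()
    # Farthest document of the same class, excluding the anchor itself.
    hardest_pos = dist.masked_fill(~same | eye, float("-inf")).max(dim=1).values
    # Closest document of any other class.
    hardest_neg = dist.masked_fill(same, float("inf")).min(dim=1).values
    return torch.relu(hardest_pos - hardest_neg + margin).mean()


# Usage with made-up labels; "d" has too few documents for K=3.
torch.manual_seed(0)
labels = ["a"] * 4 + ["b"] * 3 + ["c"] * 5 + ["d"]
class_ids = {c: i for i, c in enumerate(sorted(set(labels)))}
batches = list(pk_batches(labels, p_classes=2, k_samples=3))
batch = batches[0]
emb = torch.randn(len(batch), 16, requires_grad=True)
y = torch.tensor([class_ids[labels[i]] for i in batch])
loss = batch_hard_triplet_loss(emb, y)
loss.backward()
print(len(batches), len(batch), loss.item())
```

With this layout the batch size is P*K rather than 1000, and documents from classes unrelated to the anchor appear as negatives automatically. Small classes are simply skipped here; sampling them with replacement, or choosing the P classes of a batch from one confusable group, are possible variations.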

I’d appreciate any comment.

Thanks in advance.

Hi, if you are still monitoring the PyTorch forums: I had a similar problem training a model with triplet loss. Were you able to get through this project? If so, could you please give me some pointers?