Distributed InfoNCE Loss (CLIP)

I am trying to implement the InfoNCE loss from CLIP in a distributed way. InfoNCE is a loss function used for contrastive learning, and it benefits from large batch sizes because every other sample in the batch serves as a negative. In CLIP, a batch is composed of image-text pairs, and there is an image encoder and a text encoder. These encoders extract image and text embeddings, which are then used to compute the pairwise cosine similarities and, finally, the loss.
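For reference, here is a minimal single-process sketch of the symmetric CLIP-style InfoNCE loss described above. It assumes L2-normalised embeddings and a fixed temperature of 0.07 (CLIP actually learns the temperature as a logit scale); the function name clip_loss is hypothetical.

```python
import torch
import torch.nn.functional as F


def clip_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE over a batch of N matched image-text pairs.

    Row i of image_emb is assumed to pair with row i of text_emb; every other
    row in the batch acts as a negative. The fixed temperature is an
    assumption (CLIP learns it as a logit scale).
    """
    # Cosine similarity = dot product of L2-normalised vectors.
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / temperature  # (N, N)
    # The positive for row i sits on the diagonal, at column i.
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, labels)      # image -> text
    loss_t2i = F.cross_entropy(logits.t(), labels)  # text -> image
    return (loss_i2t + loss_t2i) / 2


if __name__ == "__main__":
    torch.manual_seed(0)
    images, texts = torch.randn(8, 32), torch.randn(8, 32)
    print(clip_loss(images, texts).item())
```

The N × N logit matrix is what makes a larger batch give each pair more negatives, and it is also the part whose memory grows quadratically with the batch size.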

In the original CLIP paper, 256 GPUs with a batch size of 128 per GPU are used (a global batch of 32,768). The authors also mention that they compute the pairwise cosine similarities over the full batch across all GPUs; you can refer to this issue. With vanilla DistributedDataParallel, the cosine similarity and the loss would be computed only over the pairs within a single GPU, and the gradients would then be averaged and synced. What we really want is to gather the embeddings from all GPUs and compute the cosine similarity of each local batch against the batches from all other GPUs, as described in the issue.
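To illustrate why per-GPU losses under vanilla DDP optimise a different objective, the sketch below simulates W "GPUs" in one process by splitting a global batch into shards. The clip_loss helper is hypothetical (redefined so the block runs on its own), and the shard sizes are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F


def clip_loss(image_emb, text_emb, temperature=0.07):
    # Symmetric InfoNCE; the fixed temperature is an assumption.
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2


def compare_local_vs_global(world_size=4, per_gpu_batch=8, dim=32, seed=0):
    torch.manual_seed(seed)
    n = world_size * per_gpu_batch
    images, texts = torch.randn(n, dim), torch.randn(n, dim)

    # Desired objective: each image is contrasted with all n texts (n - 1 negatives).
    global_loss = clip_loss(images, texts)

    # Vanilla DDP: each rank only sees its own shard (per_gpu_batch - 1 negatives).
    # DDP averages gradients, which is the gradient of this mean of local losses.
    local_losses = [
        clip_loss(img, txt)
        for img, txt in zip(images.chunk(world_size), texts.chunk(world_size))
    ]
    ddp_style_loss = torch.stack(local_losses).mean()
    return global_loss, ddp_style_loss


if __name__ == "__main__":
    g, l = compare_local_vs_global()
    print(f"global-batch loss: {g.item():.4f}  averaged per-GPU loss: {l.item():.4f}")
```

With a single shard the two numbers coincide; with several shards they differ because each local loss sees far fewer negatives.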

To test this idea on CPU before spending GPU compute, I wrote some dummy code in 2 gists: one script is a test case that computes the loss and runs backward without a distributed context, and the other tries to implement the idea in a distributed context. Although image_encoder.weight.grad appears to match the non-distributed version, text_encoder.weight.grad is None.
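A likely reason for a None gradient is that dist.all_gather fills preallocated tensors in place and is not recorded by autograd, so any embedding that reaches the loss only through its output is cut off from its encoder. A common community pattern (not an official PyTorch API) wraps the gather in a custom autograd Function, sketched below; the names GatherLayer and gather_with_grad are used here for illustration. The demo runs a single-process "gloo" group with world_size=1 only to exercise the code path. Recent PyTorch versions also ship torch.distributed.nn.functional.all_gather as a differentiable gather; it may be worth checking its backward behaviour in your version.

```python
import os
import tempfile

import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """all_gather that keeps the autograd graph (community pattern, hedged)."""

    @staticmethod
    def forward(ctx, x):
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        # Not tracked by autograd on its own; the backward below supplies gradients.
        dist.all_gather(out, x.contiguous())
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        # Every rank's loss used every slot, so sum each slot's gradient
        # across ranks, then keep the slice that belongs to this rank.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
        return all_grads[dist.get_rank()]


def gather_with_grad(x):
    # Hypothetical helper: concatenate the gathered (n, d) shards into (N, d).
    return torch.cat(GatherLayer.apply(x), dim=0)


def demo_single_process():
    # world_size=1 on the CPU "gloo" backend, using a file-based rendezvous.
    with tempfile.TemporaryDirectory() as tmp:
        init = "file://" + os.path.join(tmp, "store")
        dist.init_process_group("gloo", init_method=init, rank=0, world_size=1)
        try:
            x = torch.randn(3, 4, requires_grad=True)
            gather_with_grad(x).pow(2).sum().backward()
            return x, x.grad
        finally:
            dist.destroy_process_group()
```

In training, each rank would gather both the image and the text embeddings this way and compute the loss on the concatenated tensors. By my reading, when every rank computes the same full loss, the SUM all-reduce in backward combined with DDP's gradient averaging yields the non-distributed gradient, but this is worth verifying against the single-process reference.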

Regarding above, I have a few questions:

  1. Is there any existing code, or are there best practices, for implementing distributed contrastive losses that require communicating model outputs across devices before the loss is computed?
  2. If the answer to 1) is no, how can I fix the issue I see in my gists? Is this a good starting point for implementing the idea?
  3. Currently I use dist.all_gather since this is a dummy example, but in a real training setting, would using dist.gather help avoid possible OOM issues?
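On question 3, the gathered embeddings are only N × d; the quadratic memory cost is the N × N logit matrix. dist.gather would still leave one rank holding that whole matrix and would require sending gradients back, so on its own it seems unlikely to help with OOM. One alternative, sketched below under the assumption of equal per-rank batch sizes, is for each rank to compute only its own rows of the loss against the gathered global batch, giving n × N logits per rank. The function names are hypothetical, and the check simulates the ranks in one process.

```python
import torch
import torch.nn.functional as F


def local_clip_loss(local_img, local_txt, all_img, all_txt, rank, temperature=0.07):
    """This rank's rows of the global symmetric InfoNCE loss.

    local_*: (n, d) embeddings computed on this rank.
    all_*:   (N, d) embeddings gathered from all ranks, ordered by rank
             (in real training: a gradient-preserving all_gather).
    Assumes every rank has the same n, so this rank's positives sit at
    columns rank * n ... rank * n + n - 1. The fixed temperature is an assumption.
    """
    local_img, local_txt = F.normalize(local_img, dim=-1), F.normalize(local_txt, dim=-1)
    all_img, all_txt = F.normalize(all_img, dim=-1), F.normalize(all_txt, dim=-1)
    n = local_img.size(0)
    labels = torch.arange(n, device=local_img.device) + rank * n
    # (n, N) logits instead of (N, N): a factor of world_size smaller per rank.
    logits_i2t = local_img @ all_txt.t() / temperature
    logits_t2i = local_txt @ all_img.t() / temperature
    return (F.cross_entropy(logits_i2t, labels) + F.cross_entropy(logits_t2i, labels)) / 2


def clip_loss(image_emb, text_emb, temperature=0.07):
    # Full (N, N) version, used only as a reference here.
    image_emb, text_emb = F.normalize(image_emb, dim=-1), F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2


def check_equivalence(world_size=4, n=8, dim=32, seed=0):
    # Averaging the per-rank losses should reproduce the global-batch loss.
    torch.manual_seed(seed)
    img = torch.randn(world_size * n, dim)
    txt = torch.randn(world_size * n, dim)
    shards = zip(img.chunk(world_size), txt.chunk(world_size))
    per_rank = [local_clip_loss(i, t, img, txt, rank) for rank, (i, t) in enumerate(shards)]
    return torch.stack(per_rank).mean(), clip_loss(img, txt)
```

Because the mean of the per-rank losses equals the global loss, averaging gradients across ranks (as DDP does) should target the same objective, provided the gather passes gradients back to the rank that produced each shard.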

Thanks a lot in advance!


Update: I updated the code here to use GatherLayer. Now text_encoder.weight.grad exists, but its values are not the same as in the non-distributed version.
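One way to narrow down such a mismatch (a debugging sketch, not a diagnosis of the gist) is to simulate the ranks in a single process. Each simulated rank backpropagates the full global loss only through its own shard, with the other shards detached, which is what gathering without a gradient all-reduce amounts to. Summing those per-rank gradients reproduces the non-distributed gradient, while averaging them, as DDP does, gives 1/W of it. A mismatch that is a clean factor of the world size would therefore point at gradient scaling rather than at the gather itself. The linear encoders and all sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def clip_loss(image_emb, text_emb, temperature=0.07):
    # Symmetric InfoNCE; the fixed temperature is an assumption.
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2


def simulate_ranks(world_size=2, per_rank=4, d_in=16, d_emb=8, seed=0):
    torch.manual_seed(seed)
    # Toy stand-ins for the two encoders (assumed linear, as in a dummy test).
    image_encoder = nn.Linear(d_in, d_emb, bias=False)
    text_encoder = nn.Linear(d_in, d_emb, bias=False)
    params = [image_encoder.weight, text_encoder.weight]
    images = torch.randn(world_size * per_rank, d_in)
    texts = torch.randn(world_size * per_rank, d_in)

    # Reference: non-distributed gradient of the global-batch loss.
    ref = torch.autograd.grad(clip_loss(image_encoder(images), text_encoder(texts)), params)

    per_rank_grads = []
    for rank in range(world_size):
        # Only this rank's shard stays attached to the graph; shards from
        # other ranks arrive as constants, like a plain all_gather.
        img_shards = [image_encoder(x) for x in images.chunk(world_size)]
        txt_shards = [text_encoder(x) for x in texts.chunk(world_size)]
        img = torch.cat([s if j == rank else s.detach() for j, s in enumerate(img_shards)])
        txt = torch.cat([s if j == rank else s.detach() for j, s in enumerate(txt_shards)])
        per_rank_grads.append(torch.autograd.grad(clip_loss(img, txt), params))

    summed = [sum(g[i] for g in per_rank_grads) for i in range(len(params))]
    averaged = [s / world_size for s in summed]  # what DDP's gradient averaging yields
    return ref, summed, averaged
```

If the distributed gradients on each rank match `averaged` rather than `ref`, the missing factor is the world size, which is what the SUM all-reduce in the gather's backward is meant to restore under DDP.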