CLIP Model Batching When the Batch Size Is Limited

Problem: I am training a CLIP-style model for a pet project that connects positive pairs of image-caption data. Notably, the images are passed through a vision transformer and the text through an LLM to create an embedding for each, and both embeddings are then projected to the same dimension for the CLIP loss.

I am fortunate to have an 8 x L40S GPU server set up with 48 GB of GDDR6 memory per GPU. However, the size of the LLMs limits the batch size (of image and text data) to about 4 per GPU.

This normally wouldn’t be a problem in a distributed setting, as gradient accumulation or other training tricks work fine with a traditional loss such as MSE. However, CLIP models operate by creating synthetic negative data. Essentially, every batch of size B is expanded to a B x B matrix, where the correct pairs of data lie on the diagonal and every other combination of inputs (so every image with every other caption, and vice versa) is treated as a negative sample. Then the cosine similarity (or distance) is computed between every pair, and the loss pushes positive pairs closer together and all other pairs apart. Functionally, this is achieved by treating the scaled similarities as logits and applying cross-entropy loss across the rows and columns of the matrix.
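For concreteness, my mental model of that loss on a single GPU is roughly the following minimal sketch (function and variable names are my own, and the temperature value is just the usual CLIP-style default):

```python
import torch
import torch.nn.functional as F

def clip_loss(image_emb, text_emb, temperature=0.07):
    # image_emb, text_emb: (B, D) projected embeddings
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)

    # B x B matrix of cosine similarities, scaled into logits
    logits = image_emb @ text_emb.t() / temperature

    # The correct (positive) pairs sit on the diagonal
    targets = torch.arange(logits.size(0), device=logits.device)

    # Symmetric cross-entropy across rows (image -> text) and columns (text -> image)
    loss_i = F.cross_entropy(logits, targets)
    loss_t = F.cross_entropy(logits.t(), targets)
    return (loss_i + loss_t) / 2
```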

The problem is that CLIP training benefits from a large batch size: the larger the batch, the more synthetic negatives can be created. In my case, a batch size of 4 yields a 4 x 4 matrix of 16 pairs per GPU before the loss calculation and weight update. I want the contrastive batch to be much larger, for example 1,024 or 2,048. My current DDP setup does not do this; each GPU only ever sees its own 4 x 4 matrix.

Note - I cannot pre-compute the embeddings from each transformer because I would like to update the weights of some of the last few transformer blocks (fine-tuning) to yield better performance on my downstream application.

Potential Solution: One idea I had is to let the model accumulate embeddings. If a batch size of 4 fits on each GPU, then run 32 micro-batches of data (32 * 4 = 128 embeddings per GPU), gather the embeddings from all GPUs (128 * 8 = 1,024) to form one large batch, and, because the embeddings themselves are small, compute the full 1,024 x 1,024 similarity matrix for the CLIP loss. Finally, call loss.backward() on that loss so the gradients flow back through the data on each GPU and are aggregated to update the LLM and ViT parameters.
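Concretely, I was imagining something like the untested sketch below. I am assuming torch.distributed.nn.all_gather is the autograd-aware gather, clip_loss is the function from the earlier snippet, and the encoder/loader names are placeholders for my actual models and dataloader:

```python
import torch
import torch.distributed.nn  # exposes an autograd-aware all_gather

ACCUM_STEPS = 32  # 32 micro-batches of 4 per GPU -> 128 embeddings per GPU

def training_step(image_encoder, text_encoder, loader_iter, optimizer):
    img_chunks, txt_chunks = [], []

    # 1) Accumulate embeddings on this GPU. This keeps 32 forward graphs
    #    alive per GPU, which I am hoping fits since only the last few
    #    transformer blocks require grad.
    for _ in range(ACCUM_STEPS):
        images, texts = next(loader_iter)
        img_chunks.append(image_encoder(images))  # (4, D)
        txt_chunks.append(text_encoder(texts))    # (4, D)

    img_emb = torch.cat(img_chunks)  # (128, D) on this rank
    txt_emb = torch.cat(txt_chunks)

    # 2) Gather embeddings from all 8 GPUs. torch.distributed.nn.all_gather
    #    is supposed to route gradients back to each rank's local chunk,
    #    unlike the plain torch.distributed.all_gather.
    all_img = torch.cat(torch.distributed.nn.all_gather(img_emb))  # (1024, D)
    all_txt = torch.cat(torch.distributed.nn.all_gather(txt_emb))

    # 3) Full 1,024 x 1,024 CLIP loss (clip_loss as defined above)
    loss = clip_loss(all_img, all_txt)

    # 4) Backward through the locally accumulated graphs. I am not sure how
    #    DDP's bucketed gradient sync behaves with many forwards per backward,
    #    which is part of my question below.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```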

My Question: For my potential solution, I am wondering whether this is the proper approach for this situation or whether there is a better one. I am also unsure how to implement it in practice. How do I collect the embeddings from each GPU to form the CLIP matrix? If I call loss.backward() after computing this loss, is the compute graph still intact, even in a DDP setting? I am less familiar with how DDP works on the backend.

Any advice would be appreciated!