Accumulating Batches for Contrastive Loss

I have a custom dataset in which each example is fairly large (batch, 80, 105, 90). I am training a self-supervised model with a contrastive loss that requires a reasonably large batch size (say 128). My problem is that only 2 examples fit into GPU memory at once. However, before computing the loss, my model reduces each example to a tensor of shape (batch, 700). Does it make sense to accumulate these latent examples (which should fit into memory) and then compute the loss with a bigger batch size? The idea would go something like this:

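import torch

# Training loop (model, loader, contrastive_loss, optimizer and accumulate are defined elsewhere)
loader_iter = iter(loader)  # a DataLoader is not an iterator; wrap it before calling next()
latent = []
for _ in range(accumulate):
  pre_batch = next(loader_iter)
  latent.append(model(pre_batch))  # latent of shape (batch, 700)

latent = torch.cat(latent)  # (accumulate * batch, 700)
loss = contrastive_loss(latent)

optimizer.zero_grad()
loss.backward()
optimizer.step()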

Based on this post, my understanding is that this scenario is the “unlucky #3”, where both computation and memory are high (in fact, when I run this code I get a CUDA OOM error). Is there a smarter way of doing this, or am I missing something trivial here?

It doesn’t matter how small the output tensors are: each forward pass stores its entire computation graph, including all intermediate activations, and none of it is freed until backward() is called on a loss that depends on it.
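To make the point about intermediates concrete, here is a small sketch that counts how many elements autograd saves for the backward pass, using the real torch.autograd.graph.saved_tensors_hooks context manager. The toy_encoder function is a hypothetical stand-in for the real model (not the OP's architecture); it takes an input of the size from the question and reduces it to a tiny output, yet the saved intermediates are as large as the input.

```python
import torch


def saved_for_backward(fn, x):
    """Run fn(x) and report how many elements autograd saved for the backward pass."""
    numels = []

    def pack(t):  # called once for every tensor autograd saves for backward
        numels.append(t.numel())
        return t

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = fn(x)
    return out, sum(numels)


def toy_encoder(x):
    """Hypothetical stand-in for the real model: big input, tiny output."""
    h = torch.exp(2 * x)       # exp keeps its output for backward
    h = torch.sin(h + 1)       # sin keeps its input (h + 1) for backward
    return h.mean(dim=(2, 3))  # (B, 80, 105, 90) -> (B, 80)


x = torch.randn(2, 80, 105, 90, requires_grad=True)
out, n_saved = saved_for_backward(toy_encoder, x)
print(out.numel(), n_saved)  # a 160-element output keeps millions of elements alive
```

Appending out to a Python list, as in the training loop above, keeps all of those saved tensors alive for every micro-batch, which is why the concatenated latents alone do not tell you the memory cost.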

Yes, you are right that this approach will take a lot of memory, so would calling backward in each batch and accumulating the gradients work?
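For reference, standard gradient accumulation looks roughly like this minimal sketch; the linear model and MSE loss are placeholders chosen only to show the mechanics. Each micro-batch gets its own backward() call, so its graph is freed immediately and gradients are summed into the parameters' .grad buffers; optimizer.step() would then be called once after the loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def accumulate_grads(model, inputs, targets, accumulate):
    """Standard gradient accumulation: one backward per micro-batch, shared .grad buffers."""
    model.zero_grad()
    for xb, yb in zip(inputs.chunk(accumulate), targets.chunk(accumulate)):
        # Divide by `accumulate` so the summed grads match the full-batch mean loss
        loss = F.mse_loss(model(xb), yb) / accumulate
        loss.backward()  # frees this micro-batch's graph; grads add into .grad


torch.manual_seed(0)
model = nn.Linear(10, 1)
x, y = torch.randn(8, 10), torch.randn(8, 1)

accumulate_grads(model, x, y, accumulate=4)  # 4 micro-batches of 2
acc_grad = model.weight.grad.clone()

model.zero_grad()  # reference: a single backward over all 8 samples
F.mse_loss(model(x), y).backward()
print(torch.allclose(acc_grad, model.weight.grad, atol=1e-6))
```

The two gradients agree here only because the mean of the MSE over 8 samples equals the average of the four 2-sample means (with equal chunk sizes). Whether this holds for a given loss is exactly what decides if plain gradient accumulation is a valid substitute for a large batch.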

Hi @ptrblck! Thanks so much for the reply!

@ptrblck
Yes, you are right that this approach will take a lot of memory, so would calling backward in each batch and accumulating the gradients work?

As I see it, there is a problem with calling backward (I assume you mean loss.backward()) on each batch and simply accumulating the gradients, but please correct me if this is nonsense. Given my GPU memory constraints, each gradient would be computed with a batch_size of 2. The contrastive loss (heavily inspired by the CLIP loss) does not depend on the batch size through simple averaging: each example is contrasted only against the other examples in the same batch, so a batch of 2 provides just one negative per example, and the quality of the learning signal degrades for smaller batches. I therefore fear that standard gradient accumulation would average “very poor” gradients, resulting in poor overall performance. This reasoning is what prompted me to accumulate the batches rather than simply the gradients.
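The concern can be checked numerically. The sketch below uses a hypothetical CLIP-style symmetric loss, clip_loss, with an assumed temperature of 0.1 and two embeddings per example (za, zb); the OP's actual loss takes a single latent tensor and is not shown in the thread. It compares the loss on one batch of 128 against the average over 64 micro-batches of 2.

```python
import torch
import torch.nn.functional as F


def clip_loss(za, zb, temperature=0.1):
    """Hypothetical CLIP-style symmetric loss: row i of za is the positive for row i of zb."""
    za, zb = F.normalize(za, dim=1), F.normalize(zb, dim=1)
    logits = za @ zb.t() / temperature   # (B, B); off-diagonal entries act as negatives
    targets = torch.arange(za.shape[0])  # positives sit on the diagonal
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


torch.manual_seed(0)
za, zb = torch.randn(128, 700), torch.randn(128, 700)

full = clip_loss(za, zb)  # 127 negatives per positive
micro = torch.stack(
    [clip_loss(a, b) for a, b in zip(za.split(2), zb.split(2))]
).mean()  # only 1 negative per positive
single = clip_loss(za[:1], zb[:1])  # no negatives at all: the loss is identically 0

print(full.item(), micro.item(), single.item())
```

For random, nearly orthogonal embeddings, full comes out close to log(128) while micro is close to log(2), so averaging micro-batch losses (and hence their gradients) optimizes a different, much weaker objective than the large-batch loss.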

For the time being, I came up with a simple workaround: I wrapped several parts of the model in torch.utils.checkpoint.checkpoint and managed to fit a batch_size of 100, which is close to what I was aiming for, at the cost of extra computation (checkpointed segments are recomputed during the backward pass).
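A minimal sketch of this kind of checkpointing, assuming a hypothetical encoder built from repeated blocks; where the checkpoints were placed in the real model is not stated. Only each block's input is stored during the forward pass, and the activations inside the block are recomputed during backward, so the gradients are unchanged.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedEncoder(nn.Module):
    """Hypothetical encoder; each block's internals are recomputed during backward."""

    def __init__(self, dim=64, depth=4, out_dim=16):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)]
        )
        self.head = nn.Linear(dim, out_dim)
        self.use_checkpoint = True

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Store only the block input; recompute the block in backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.head(x)


torch.manual_seed(0)
model = CheckpointedEncoder()
x = torch.randn(100, 64)  # batch of 100

model(x).pow(2).mean().backward()  # checkpointed pass
ckpt_grad = model.blocks[0][0].weight.grad.clone()

model.zero_grad()
model.use_checkpoint = False  # ordinary pass with the same weights
model(x).pow(2).mean().backward()
print(torch.allclose(ckpt_grad, model.blocks[0][0].weight.grad))
```

use_reentrant=False is passed explicitly because the non-reentrant variant still propagates gradients to the block's parameters when the input itself does not require grad. With dropout inside a block, checkpoint restores the RNG state by default so the recomputation matches the original forward.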

I think your explanation makes sense and it’s good to hear checkpointing could help.
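Another option, not discussed in this thread, is gradient caching (the idea behind the GradCache approach): embed all micro-batches without a graph, compute the full-batch loss on the small cached embeddings, then re-run each micro-batch with a graph and backpropagate the cached embedding gradients. This sketch assumes the same hypothetical two-view clip_loss as a stand-in for the real loss and a deterministic model (no dropout), so the second pass reproduces the first. It costs roughly one extra forward pass, and only one micro-batch graph is alive at a time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def clip_loss(za, zb, temperature=0.1):
    """Hypothetical CLIP-style loss: row i of za is the positive for row i of zb."""
    za, zb = F.normalize(za, dim=1), F.normalize(zb, dim=1)
    logits = za @ zb.t() / temperature
    targets = torch.arange(za.shape[0], device=za.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


def grad_cache_backward(model, views_a, views_b):
    """Fill .grad as if clip_loss had seen all micro-batches as one batch."""
    # Pass 1: embed every micro-batch without building any graph
    with torch.no_grad():
        za = torch.cat([model(x) for x in views_a])
        zb = torch.cat([model(x) for x in views_b])
    # Full-batch loss on the cached embeddings; backward only reaches za and zb
    za.requires_grad_(True)
    zb.requires_grad_(True)
    loss = clip_loss(za, zb)
    loss.backward()
    grads_a = za.grad.split([x.shape[0] for x in views_a])
    grads_b = zb.grad.split([x.shape[0] for x in views_b])
    # Pass 2: recompute one micro-batch at a time and push dL/dz into the parameters
    for x, g in zip(views_a, grads_a):
        model(x).backward(gradient=g)
    for x, g in zip(views_b, grads_b):
        model(x).backward(gradient=g)
    return loss.detach()


torch.manual_seed(0)
model = nn.Linear(16, 8)
views_a = [torch.randn(2, 16) for _ in range(4)]  # 4 micro-batches of 2
views_b = [torch.randn(2, 16) for _ in range(4)]

model.zero_grad()
grad_cache_backward(model, views_a, views_b)
cached = model.weight.grad.clone()

model.zero_grad()  # reference: one ordinary backward over the full batch of 8
clip_loss(model(torch.cat(views_a)), model(torch.cat(views_b))).backward()
print(torch.allclose(cached, model.weight.grad, atol=1e-5))
```

The two gradients match because the parameter gradient is the sum over examples of dL/dz_i times dz_i/dθ, and pass 2 computes each of those vector-Jacobian products separately. An optimizer.step() would follow the call as usual.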