Gradient accumulation for contrastive loss

I am trying to implement gradient accumulation for large-batch training. In the case of categorical cross-entropy loss, I would implement gradient accumulation in the following way:

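import torch.nn as nn

criterion = nn.CrossEntropyLoss()
accumulation_steps = 5

optimizer.zero_grad()
for idx, (x, y) in enumerate(data_loader, 1):
    output = model(x)
    # Scale so the summed gradients match one batch accumulation_steps times larger
    loss = criterion(output, y) / accumulation_steps
    loss.backward()  # gradients add up in each parameter's .grad

    if idx % accumulation_steps == 0:
        optimizer.step()       # update with the accumulated gradients
        optimizer.zero_grad()  # reset before the next accumulation window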
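Why the split works for cross-entropy can be checked numerically. The NumPy sketch below uses hypothetical random logits and labels (not real model outputs) and compares the mean loss over 20 samples with the sum of five 4-sample mini-batch losses, each divided by `accumulation_steps`. Because the two quantities are the same function of the outputs, their gradients agree as well.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean categorical cross-entropy over the rows of `logits`."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(20, 4))        # hypothetical model outputs
labels = rng.integers(0, 4, size=20)     # hypothetical targets
accumulation_steps = 5

full_batch_loss = cross_entropy(logits, labels)
# Same split as the training loop: 5 mini-batches of 4 samples,
# each mini-batch loss divided by accumulation_steps and summed.
accumulated_loss = sum(
    cross_entropy(l, y) / accumulation_steps
    for l, y in zip(np.split(logits, accumulation_steps),
                    np.split(labels, accumulation_steps))
)
print(np.isclose(full_batch_loss, accumulated_loss))  # True
```

The equality relies on every mini-batch having the same size; a smaller final batch would be slightly over-weighted by this scaling.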

I am unable to use the same strategy with a contrastive loss. The contrastive loss for a sample depends on all the other samples in its mini-batch (for example, because they act as its negatives). With accumulated mini-batches, the loss for a single sample should therefore depend on the outputs of all samples across the accumulated mini-batches, not only those in its own mini-batch. It is not obvious to me how to handle this.
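A small NumPy check makes the problem concrete. It assumes an InfoNCE-style contrastive loss on L2-normalized embeddings (the question does not say which contrastive loss is used) and hypothetical random embeddings for two views. Splitting the batch into 5 chunks and averaging the chunk losses does not reproduce the 20-sample loss, because inside a chunk each sample only sees 3 negatives instead of 19.

```python
import numpy as np

def info_nce(a, b, temperature=0.1):
    """InfoNCE loss: row i of `a` should match row i of `b`; the other rows
    of `b` in the same batch are the negatives for sample i."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)   # cosine similarity
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    logits = a @ b.T / temperature
    logits -= logits.max(axis=1, keepdims=True)        # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.diag(log_probs).mean()

rng = np.random.default_rng(0)
a = rng.normal(size=(20, 8))              # hypothetical embeddings, view 1
b = a + 0.5 * rng.normal(size=(20, 8))    # hypothetical embeddings, view 2
accumulation_steps = 5

full_batch_loss = info_nce(a, b)
accumulated_loss = sum(
    info_nce(ca, cb) / accumulation_steps
    for ca, cb in zip(np.split(a, accumulation_steps),
                      np.split(b, accumulation_steps))
)
# Fewer negatives per sample means a smaller log-sum-exp, so the chunked
# loss is strictly smaller and, in general, has different gradients.
print(full_batch_loss > accumulated_loss)  # True
```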

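  • One option is to store the outputs output = model(x) of every mini-batch and call loss.backward() only once every accumulation_steps iterations. However, each stored output keeps its whole computation graph alive until that backward pass, so GPU memory grows with the number of accumulated mini-batches, which defeats the purpose of accumulating.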
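A variant of this option keeps only the embeddings, not their graphs. This is the idea behind gradient caching (GradCache): (1) embed every mini-batch without tracking gradients, (2) compute the contrastive loss over the whole accumulated batch and its gradient with respect to the embeddings only, (3) re-run the encoder one mini-batch at a time and backpropagate the cached embedding gradients. Memory then holds one mini-batch of activations plus the embeddings of the full batch, at the cost of a second forward pass. The sketch below is a NumPy illustration under loud assumptions: a hypothetical shared linear encoder `W` stands in for `model`, the chain rule is written by hand instead of using autograd, and the InfoNCE loss skips L2 normalization to keep the gradient short. In PyTorch, step 1 would run under `torch.no_grad()` and step 3 would call `emb.backward(cached_grad)`; any randomness such as dropout would need to be reproduced identically in both passes.

```python
import numpy as np

def info_nce_and_grads(a, b, temperature=0.5):
    """Full-batch InfoNCE loss (no normalization) and its gradients
    with respect to the embeddings a and b."""
    n = len(a)
    logits = a @ b.T / temperature
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)          # row-wise softmax
    loss = -np.log(np.diag(probs)).mean()
    grad_logits = (probs - np.eye(n)) / n               # d loss / d logits
    grad_a = grad_logits @ b / temperature
    grad_b = grad_logits.T @ a / temperature
    return loss, grad_a, grad_b

rng = np.random.default_rng(0)
W = 0.2 * rng.normal(size=(8, 4))          # hypothetical linear encoder
xa = rng.normal(size=(20, 8))              # hypothetical inputs, view 1
xb = xa + 0.1 * rng.normal(size=(20, 8))   # hypothetical inputs, view 2
accumulation_steps = 5

# Step 1: embed each mini-batch; only the small embedding arrays are kept.
a = np.concatenate([c @ W for c in np.split(xa, accumulation_steps)])
b = np.concatenate([c @ W for c in np.split(xb, accumulation_steps)])

# Step 2: loss over all 20 samples, differentiated w.r.t. the embeddings only.
loss, grad_a, grad_b = info_nce_and_grads(a, b)

# Step 3: revisit each mini-batch and push its cached embedding gradients
# through the encoder (chain rule for z = x @ W), accumulating into grad_W.
grad_W = np.zeros_like(W)
for xa_c, xb_c, ga_c, gb_c in zip(np.split(xa, accumulation_steps),
                                  np.split(xb, accumulation_steps),
                                  np.split(grad_a, accumulation_steps),
                                  np.split(grad_b, accumulation_steps)):
    grad_W += xa_c.T @ ga_c + xb_c.T @ gb_c
```

Here `grad_W` equals the gradient of the 20-sample contrastive loss, even though the encoder only ever processed 4 samples at a time with gradients enabled.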

How can gradient accumulation be performed in such a scenario? Looking forward to your assistance, and thanks in advance!