Fine-grained control over gradient computation

Hi,
I am doing contrastive learning in a setting similar to CLIP (i.e. I have n correct pairs of samples in a batch, and the task is to find which of the n^2 possible pairs are correct). This is done by computing cosine similarities between the embeddings and then applying a cross-entropy loss. Since I suspect that a significant part of the information is available in only one of the two modalities, I want to add an extra link between the modalities, placed after the part of the model that is shared across all pairs and before the loss is computed. Unfortunately, this means the extra computation has to be done for all n^2 pairs rather than just for the n individual samples.
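For concreteness, here is a rough sketch of the kind of setup I mean; `enc_a`, `enc_b` and `link` are just placeholder names for my per-modality encoders and the extra pairwise module, not my actual model:

```python
import torch
import torch.nn.functional as F

def clip_style_loss(enc_a, enc_b, link, batch_a, batch_b, temperature=0.07):
    za = enc_a(batch_a)                    # (n, d) embeddings, one per sample
    zb = enc_b(batch_b)                    # (n, d)

    n = za.size(0)
    # The extra link has to see every (i, j) combination, so this part
    # scales with n^2 instead of n.
    pa = za.unsqueeze(1).expand(n, n, -1)  # (n, n, d)
    pb = zb.unsqueeze(0).expand(n, n, -1)  # (n, n, d)
    sim = link(pa, pb)                     # (n, n) pairwise scores

    labels = torch.arange(n, device=sim.device)
    # Symmetric cross-entropy over rows and columns, as in CLIP.
    return 0.5 * (F.cross_entropy(sim / temperature, labels)
                  + F.cross_entropy(sim.t() / temperature, labels))
```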
At inference time this is doable if I compute the cosine similarities in sub-batches. During training, however, autograd stores the intermediate activations needed for gradient computation for all n^2 pairs. This severely limits the batch size, which for contrastive learning should really be as large as possible.
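For reference, the sub-batched inference version I have in mind, which doesn't have this memory problem, looks roughly like this (placeholder names again, arbitrary chunk size):

```python
import torch

@torch.no_grad()
def pairwise_scores_chunked(za, zb, link, chunk=256):
    """Score all n^2 pairs in row chunks; no graph is kept, so memory stays flat."""
    n = za.size(0)
    rows = []
    for i in range(0, n, chunk):
        pa = za[i:i + chunk].unsqueeze(1).expand(-1, n, -1)  # (chunk, n, d)
        pb = zb.unsqueeze(0).expand(pa.size(0), -1, -1)      # (chunk, n, d)
        rows.append(link(pa, pb))                            # (chunk, n)
    return torch.cat(rows, dim=0)                            # (n, n)
```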
I could still use larger batches if I could compute the gradients one group at a time, where a group consists of all pairs that share a sample with a given correct pair, and then aggregate the gradients. I know it is possible to compute per-sample gradients (I saw https://pytorch.org/tutorials/intermediate/per_sample_grads.html), but that would be inefficient because the gradients for the shared part of the model would be recomputed for every correct pair.
So my question is: can I compute the gradients per pair in the part of the model that links the modalities, aggregate them, and then compute the gradients for the rest of the model and feed them to the optimizer?

It would be ideal if I could do this without changing how the model has to be handled during training, because I use PyTorch Lightning and want to keep using it.

Hey,
The process you’re describing sounds very similar to a technique called GradCache (Paper, Implementation). It basically first computes the gradients of the loss with respect to the representations (without computing the model gradients), then accumulates the model gradients over smaller sub-batches that fit into memory, reusing the cached representation gradients, and finally takes the optimizer step. I think it supports PyTorch Lightning, but I’m not sure.
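The core idea can also be hand-rolled. Below is an untested sketch of the mechanism, not the GradCache API; `enc_a`, `enc_b`, `link` and `chunk_loss_fn` are placeholders for your encoders, the pairwise module, and a function that computes the loss contribution of a chunk of rows:

```python
import torch

def cached_gradient_step(enc_a, enc_b, link, chunk_loss_fn,
                         batch_a, batch_b, optimizer, chunk=256):
    # Stage 1: run the shared encoders once and detach the embeddings, so the
    # pairwise part builds no graph through the encoders.
    za = enc_a(batch_a)
    zb = enc_b(batch_b)
    za_d = za.detach().requires_grad_(True)
    zb_d = zb.detach().requires_grad_(True)

    # Stage 2: run the pairwise link + loss in row chunks; each backward frees
    # its chunk's graph, accumulates gradients into the link's parameters, and
    # caches gradients w.r.t. the embeddings in za_d.grad / zb_d.grad.
    n = za_d.size(0)
    total = 0.0
    for i in range(0, n, chunk):
        # chunk_loss_fn is assumed to compute the (partial) contrastive loss
        # for rows [i, i + chunk) using the link module.
        loss = chunk_loss_fn(link, za_d, zb_d, i, i + chunk)
        loss.backward()
        total += loss.item()

    # Stage 3: one backward pass through the encoders using the cached
    # embedding gradients, then the usual optimizer step.
    torch.autograd.backward([za, zb], [za_d.grad, zb_d.grad])
    optimizer.step()
    optimizer.zero_grad()
    return total
```

One caveat: if the loss also normalises over columns (as in the symmetric CLIP loss), the per-chunk losses are not fully independent, so the chunked loss computation needs some care.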