Computing InfoNCE requires gathering the encoded representations from all GPUs for full negative sampling. Specifically, to compute the InfoNCE loss, many repositories (e.g. SimCLR, essl) use 1) torch.distributed.all_gather() to gather features from all GPUs at forward() and 2) torch.distributed.all_reduce() to sum the gradients at backward().
However, wouldn't the synchronized representations (via torch.distributed.all_gather()) make every GPU compute the exact same loss, so that the computation is redundant?
How about sending all features to a single machine (either a GPU or the CPU), computing the loss there, and multiplying the gradient by the number of GPUs, since the gradient will be averaged across GPUs anyway?
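For reference, here is a minimal sketch of the pattern being discussed (the names GatherLayer and info_nce_all, the temperature value, and the single similarity matrix are illustrative assumptions, not taken from any particular repository): all_gather of the features in the forward pass, all_reduce of the gradients in the backward pass, and an InfoNCE loss over the full gathered batch that ends up identical on every rank.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

class GatherLayer(torch.autograd.Function):
    """Gather tensors from all GPUs while keeping gradients flowing.

    forward: all_gather the features from every rank.
    backward: all_reduce (sum) the gradients so each rank's encoder
    receives the contribution from every rank's loss.
    """

    @staticmethod
    def forward(ctx, x):
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x)
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        grad = torch.stack(grads)       # (world_size, B, D)
        dist.all_reduce(grad)           # sum gradients across ranks
        return grad[dist.get_rank()]    # slice belonging to this rank's input


def info_nce_all(z1, z2, temperature=0.5):
    """InfoNCE over the full gathered batch -- identical on every rank."""
    z1 = torch.cat(GatherLayer.apply(z1), dim=0)   # (world_size * B, D)
    z2 = torch.cat(GatherLayer.apply(z2), dim=0)
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature             # positives on the diagonal
    labels = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, labels)
```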
That’s a fair concern. I don’t know whether the program uses DDP. With DDP, since the data fed to each GPU is different, wouldn’t the loss computed on each GPU also be different, and hence the need to all-reduce the gradients?
The inputs to the GPUs are obviously different. However, I think the set of representations used in the loss becomes the same on every GPU after all_gather() (most InfoNCE self-supervised repositories synchronize the representations so that all of them can serve as negative samples). Therefore, I assume gathering all representations on a single machine, e.g. one GPU or the CPU, could avoid the redundant loss computation.
My question is whether this concern is correct. Thank you!
I would agree with the concern. This can be done a little differently, so that after all_gather each GPU computes the loss only for its own local batch (with the gathered features as negatives). That way each GPU does a comparable amount of work on distinct rows, and the gradient all-reduce carries the result to all GPUs; see the sketch below.
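A minimal sketch of that variant, reusing GatherLayer and the imports from the earlier snippet (again with illustrative names and a simplified one-directional InfoNCE): each rank gathers the second view's features to form the negative pool, but only scores its own local rows, so the logits matrix shrinks from (world_size * B, world_size * B) to (B, world_size * B).

```python
def info_nce_local(z1_local, z2_local, temperature=0.5):
    """InfoNCE where each rank scores only its own rows against all gathered negatives."""
    rank, batch = dist.get_rank(), z1_local.size(0)
    z1 = F.normalize(z1_local, dim=1)                                        # (B, D), local rows only
    z2 = F.normalize(torch.cat(GatherLayer.apply(z2_local), dim=0), dim=1)   # (world_size * B, D)

    logits = z1 @ z2.t() / temperature                                       # (B, world_size * B)
    # The positive for local row i sits at global column rank * B + i,
    # because all_gather concatenates the per-rank batches in rank order.
    labels = torch.arange(batch, device=z1.device) + rank * batch
    return F.cross_entropy(logits, labels)
```

One detail to watch is gradient scaling: DDP averages gradients across ranks, so whether this exactly matches the full-batch loss depends on how the per-rank loss is normalized (here it is the mean over the local rows).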