Using torch.distributed.all_gather() to compute Noise Contrastive Loss with PyTorch

Computing the InfoNCE loss requires gathering the encoded representations from all GPUs so that every sample can serve as a negative. To do this, many repositories (e.g. SimCLR, essl) use 1) torch.distributed.all_gather() to gather features from all GPUs in forward() and 2) torch.distributed.all_reduce() to sum the gradients in backward().
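For reference, the gather-in-forward / reduce-in-backward pattern described above can be sketched roughly as follows. This is only an illustration, not the code of any particular repository: the name GatherLayer is hypothetical, every rank is assumed to hold the same batch size, and the single-process fallback is added here only so the snippet runs without a process group.

```python
import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """Hypothetical helper: all_gather in forward, all_reduce (sum) in backward."""

    @staticmethod
    def forward(ctx, x):
        # Fallback (an assumption for this sketch): with no process group,
        # behave like an identity so the code also runs on one process.
        if not (dist.is_available() and dist.is_initialized()):
            return x.clone()
        gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x.contiguous())
        # Concatenate in rank order: shape (world_size * n, d).
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        if not (dist.is_available() and dist.is_initialized()):
            return grad_output
        # Sum the gradients that every rank computed w.r.t. the gathered tensor.
        grad = grad_output.contiguous().clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        # Keep only the slice that belongs to this rank's own features.
        n = grad.shape[0] // dist.get_world_size()
        rank = dist.get_rank()
        return grad[rank * n:(rank + 1) * n]


# Single-process usage: acts as an identity with a pass-through gradient.
x = torch.randn(4, 8, requires_grad=True)
y = GatherLayer.apply(x)
y.sum().backward()
```

Under torchrun with an initialized process group, GatherLayer.apply(features) would return the features of all ranks concatenated in rank order.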

However, wouldn't the synchronized representations (from torch.distributed.all_gather()) make every GPU compute exactly the same loss, so that the computation is redundant?

How about sending all features to a single device (either a GPU or the CPU), computing the loss there, and multiplying the gradient by the number of GPUs, since the gradient will be averaged across GPUs?

That's a fair concern. I don't know whether the program uses DDP. With DDP, the input to each GPU is different, so maybe the loss computed on each GPU is different too, hence the need to all-reduce the gradients?

The inputs to the GPUs are obviously different. However, I think the set of representations on each GPU becomes the same after all_gather() (most InfoNCE self-supervised repositories synchronize the representations so that all of them can be used as negative samples). Therefore, I assume that gathering all representations on a single device, e.g. one GPU or the CPU, could avoid the redundant loss computation.

My question is whether this concern is valid. Thank you!

I would agree with the concern. This can be done a little differently, such that each GPU only calculates the loss for its own batch after all_gather. That way each GPU does a similar amount of computation, and the gradient reduction still carries to all GPUs.
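To make the "loss for its own batch" idea concrete, here is a minimal sketch that simulates two ranks in a single process. The function name local_info_nce, the one-directional form of the loss, and the temperature of 0.1 are assumptions made for illustration. The gathered keys are passed in as a plain tensor, so how gradients flow through them in a real run depends on which gather is used.

```python
import torch
import torch.nn.functional as F


def local_info_nce(z1_local, z2_all, rank, temperature=0.1):
    """One-directional InfoNCE computed only for this rank's rows.

    z1_local: (n, d) anchors encoded on this rank.
    z2_all:   (world_size * n, d) keys gathered from all ranks, in rank order.
    """
    n = z1_local.shape[0]
    logits = z1_local @ z2_all.t() / temperature  # (n, world_size * n)
    # The positive for local row i sits at global index rank * n + i.
    labels = rank * n + torch.arange(n, device=z1_local.device)
    return F.cross_entropy(logits, labels)


# Simulate 2 ranks in one process by slicing a "global" batch.
torch.manual_seed(0)
world_size, n, d = 2, 4, 16
z1 = F.normalize(torch.randn(world_size * n, d), dim=1)
z2 = F.normalize(torch.randn(world_size * n, d), dim=1)

# Loss over the full similarity matrix (what every GPU would compute redundantly).
full_loss = F.cross_entropy(z1 @ z2.t() / 0.1, torch.arange(world_size * n))

# Each simulated rank computes the loss for its own n rows only.
local_losses = [
    local_info_nce(z1[r * n:(r + 1) * n], z2, r) for r in range(world_size)
]
mean_local = torch.stack(local_losses).mean()
```

With equal batch sizes per rank, the mean of the per-rank losses matches the full-batch loss, so each GPU only needs an n x (world_size * n) slice of the similarity matrix instead of the whole thing.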