Each batch is split into smaller parts and distributed across the different GPUs, so each GPU holds only a partition of the full batch. After each GPU computes the embeddings for its inputs, I want to recombine them so I can compute the contrastive loss over the whole batch. I know that I have to use `dist.all_gather()` to achieve that, and that this function does not maintain the `grad_fn` property in the gathered tensors.
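To make the problem concrete, here is roughly what I mean (a minimal sketch; `model` and `local_batch` are placeholders for my actual setup, and I assume the process group has already been initialized, e.g. via torchrun):

```python
import torch
import torch.distributed as dist

# Sketch only: `model` and `local_batch` stand in for my real pipeline, and
# dist.init_process_group() is assumed to have been called already.
embeddings = model(local_batch)        # embeddings.grad_fn is set -> part of the graph

gathered = [torch.zeros_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, embeddings)  # every gathered[i].grad_fn is None, so a loss
                                       # computed on torch.cat(gathered) would not
                                       # backpropagate into the model
```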
I have found two approaches that address this issue.

Approach 1: Wrapping `dist.all_gather()` in a custom `torch.autograd.Function` that implements the backward pass, such as:
- In github.com/Spijkervet/SimCLR: the loss is implemented in simclr/modules/nt_xent.py and its autograd.Function in simclr/modules/gather.py (both quoted below).
- In github.com/PyTorchLightning: the loss implementation there likewise relies on its own autograd.Function.
```python
# https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
def nt_xent_loss(self, ...):
    ...
    if self.world_size > 1:
        z = torch.cat(GatherLayer.apply(z), dim=0)
    sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
    ...
```
```python
# https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py
class GatherLayer(torch.autograd.Function):
    """Gather tensors from all process, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
```
Approach 2: Overwriting the local partition of the gathered result, which I have seen described by JohnGiorgi in a comment on github.com/KevinMusgrave/pytorch-metric-learning/issues/10. The idea is to use the local tensor, which has the proper `grad_fn`, to overwrite the corresponding partition in the list returned by `dist.all_gather()`. (I am copying his pseudo-code here to make the question self-contained.)
```python
# From: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10
import torch
import torch.distributed as dist

# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)

# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)

# Gathered tensors have no gradient, so we overwrite the gathered tensor for the
# current replica with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings

# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)

# I didn't demonstrate how to generate the labels, this will be task-dependent.
loss = some_contrastive_loss(embeddings, labels)
```
Also, this blog post claims that when using this approach, we should re-scale the gradients by the number of GPUs (i.e., world_size) to obtain the correct gradients.
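If I understand it correctly, the re-scaling would look roughly like this (my own sketch of the claim, not code from the blog post; I assume the factor is meant to counteract DistributedDataParallel averaging the gradients across replicas):

```python
# Sketch of the re-scaling as I understand it (not taken from the blog post).
# With approach 2 only the local partition of embeddings_list carries gradients,
# and DDP averages gradients across replicas, so the claim is that scaling the
# loss (and therefore the gradients) by world_size restores the correct magnitude.
loss = some_contrastive_loss(embeddings, labels)
loss = loss * dist.get_world_size()
loss.backward()
```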
Which one should I use? Is it just a matter of style or are there meaningful differences between these approaches?