Dist.all_gather() and gradient preservation in multi-GPU training

I want to use the NT-Xent loss from the SimCLR paper, and I am unsure what the correct implementation is in a multi-GPU setting, specifically how to use dist.all_gather() properly.

Each batch is divided into smaller parts and distributed across the GPUs, so each GPU only holds a partition of the full batch. After each GPU computes the embeddings for its inputs, I want to recombine them so I can compute the contrastive loss over the full batch. I know that I have to use dist.all_gather() to achieve that, and that this function does not preserve grad_fn on the gathered tensors.
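
For concreteness, here is a minimal sketch (my own, not taken from any of the repositories below) that shows the problem; it assumes two processes launched with torchrun and uses the gloo backend so it also runs on CPU:

# grad_fn_check.py -- run with: torchrun --nproc_per_node=2 grad_fn_check.py
import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # gloo so the sketch runs without GPUs

# Stand-in for the local embeddings: part of an autograd graph, so it has a grad_fn.
z = torch.randn(4, 8, requires_grad=True) * 2.0
print("local grad_fn:", z.grad_fn)  # e.g. <MulBackward0 ...>

# dist.all_gather() fills plain tensors in place; the results are detached from the graph.
gathered = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, z)
print("gathered grad_fn:", gathered[dist.get_rank()].grad_fn)  # None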

I have found two approaches that address this issue.

#1 Extend torch.autograd.Function, such as:

  1. In github.com/Spijkervet/SimCLR: the implementation of the loss (nt_xent.py) and its autograd.Function (gather.py), both quoted below.
  2. In github.com/PyTorchLightning: the implementation of the loss and its autograd.Function (a sketch of its backward behavior follows the code below).
# https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
def nt_xent_loss(self, ...):
    ...

    if self.world_size > 1:
        z = torch.cat(GatherLayer.apply(z), dim=0)

    sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
    ...
# https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py
class GatherLayer(torch.autograd.Function):
    """Gather tensors from all process, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        # Only the gradient for this rank's own slice of the gathered output is
        # propagated back; there is no all_reduce of gradients across ranks.
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
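
For reference, the PyTorchLightning variant from the list above differs in its backward: it all_reduces the incoming gradient across ranks before slicing out the local chunk. The snippet below is only my sketch of that pattern (based on the description in the last reply of this thread, not copied verbatim from the Lightning repo):

import torch
import torch.distributed as dist

class SyncFunction(torch.autograd.Function):
    """Gather tensors from all processes, all_reducing the gradient in backward (sketch)."""

    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradient of the full gathered tensor over all ranks, then
        # return only the slice corresponding to this rank's local input.
        grad_input = grad_output.clone()
        dist.all_reduce(grad_input, op=dist.ReduceOp.SUM)
        idx_from = dist.get_rank() * ctx.batch_size
        idx_to = (dist.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]

Compared to GatherLayer above, every rank's loss now contributes to the local gradient, which is why the resulting gradient ends up roughly world_size times larger (see the last reply).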

#2 Overwrite the gathered list using the partition from the local replica

I’ve seen this approach described by JohnGiorgi in a comment on the pytorch-metric-learning issue tracker (the link is in the snippet below). The idea is to use the local tensor, which has the proper grad_fn, to overwrite the corresponding partition in the list returned by dist.all_gather().

(I am copying his pseudo-code here to make the question self-contained)

# From: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10
import torch
import torch.distributed as dist

# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)

# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)

# Gathered tensors have no gradient, so we overwrite the gathered tensor for the
# current replica with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings

# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)

# I didn't demonstrate how to generate the labels; this will be task-dependent.
loss = some_contrastive_loss(embeddings, labels)

Also, this blog post claims that when using this approach, we should re-scale the gradients by the number of GPUs (i.e., world_size) to obtain the correct gradients.
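
I haven't verified that claim, but if it applies, the simplest way I can see to apply the re-scaling is to scale the loss itself before calling backward. A sketch, under the assumption that DDP averages the gradients of the local-slice-only loss across ranks:

# Continuation of the pseudo-code above (assumption: DDP averages gradients
# across ranks, shrinking them by a factor of world_size).
loss = some_contrastive_loss(embeddings, labels)
loss = loss * dist.get_world_size()  # undo the averaging
loss.backward()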

Question

Which one should I use? Is it just a matter of style or are there meaningful differences between these approaches?


Hi, did you find the difference? I am also troubled by this question.

I’m also investigating a multi-GPU contrastive setup, and I’m currently debugging the difference between a single-GPU run with the full batch size and an x-GPU run with batch size / x, to compare the loss implementations.

Therefore, I would be very curious how others approached this and what you ended up using in the end. 🙂

For the three implementations you listed, in my opinion, the SimCLR loss implemented by Spijkervet is equivalent to the method of overwriting the gathered list. The loss implemented by PyTorchLightning is different: it all_reduces the gradient in the backward function, so the gradient will be dist.get_world_size() times larger than in the former implementations. I do not know which one is right.
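
That factor should be easy to check empirically. Here is a minimal sketch (my own; it assumes GatherLayer from the question and a SyncFunction like the sketch earlier in this thread are both defined or imported, and that the script is launched with torchrun on two processes):

# compare_gather_grads.py -- run with: torchrun --nproc_per_node=2 compare_gather_grads.py
# Assumes GatherLayer (from the question) and SyncFunction (sketch above) are importable.
import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # CPU backend so the check runs without GPUs

x1 = torch.randn(4, 8, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

out1 = torch.cat(GatherLayer.apply(x1), dim=0)  # no all_reduce in backward
out2 = SyncFunction.apply(x2)                   # all_reduce in backward

# The same simple, rank-symmetric "loss" on both gathered tensors.
out1.pow(2).sum().backward()
out2.pow(2).sum().backward()

# The ratio should come out to roughly world_size (here: 2).
print("grad norm ratio:", (x2.grad.norm() / x1.grad.norm()).item())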