I want to use the NT-Xent loss from the SimCLR paper, and I am unsure what the correct implementation is in a multi-GPU setting, specifically how to use `dist.all_gather()` properly.

Each batch is split into smaller parts that are distributed across the GPUs, so each GPU holds only one partition of the full batch. After each GPU computes the embeddings for its inputs, I want to recombine them so that I can compute the contrastive loss over the full batch. I know that I have to use `dist.all_gather()` to achieve that, and that this function does not preserve the `grad_fn` property in the gathered data.

I have found two approaches that address this issue.

### #1 Extend `torch.autograd.Function`

Examples of this approach:

- In github.com/Spijkervet/SimCLR: this is the implementation of the loss and this is its autograd.Function.
- In github.com/PyTorchLightning: this is the implementation of the loss and this is its autograd.Function.

```
# https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
def nt_xent_loss(self, ...):
    ...
    if self.world_size > 1:
        z = torch.cat(GatherLayer.apply(z), dim=0)
    sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
    ...
```

```
# https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py
import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # One receive buffer per rank; all_gather fills them in rank order.
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        # Keep only the gradient for the slice this rank contributed.
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
```

### #2 Overwrite the gathered list using the partition from the local replica

I've seen this approach described by **JohnGiorgi** in this comment. The idea is to take the local tensor, which has the proper `grad_fn`, and use it to overwrite the corresponding partition in the list filled by `dist.all_gather()`.

(I am copying his pseudo-code here to make the question self-contained)

```
# From: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10
import torch
import torch.distributed as dist
# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)
# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)
# Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings
# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)
# I didn't demonstrate how to generate the labels, this will be task-dependent.
loss = some_contrastive_loss(embeddings, labels)
```

Also, this blog post claims that when using this approach, we should rescale the gradients by the number of GPUs (i.e., `world_size`) to obtain the correct gradients.
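To make that claim concrete, here is a minimal single-process sketch of how I understand it. Everything in it is an assumption for illustration. `toy_contrastive_loss` is a hypothetical stand-in, not the NT-Xent code from either repository. The gathered tensors are modelled as detached copies. The mean over the simulated replicas stands in for the gradient averaging that `DistributedDataParallel` performs. No process group is created, so this only checks the arithmetic, not the communication.

```python
import torch
import torch.nn.functional as F


def toy_contrastive_loss(z, temperature=0.5):
    """Hypothetical stand-in loss: rows (0,1), (2,3), ... are positive pairs."""
    z = F.normalize(z, dim=1)
    sim = z @ z.t() / temperature
    n = z.shape[0]
    # Exclude self-similarity from the softmax.
    sim = sim.masked_fill(torch.eye(n, dtype=torch.bool), float("-inf"))
    targets = torch.arange(n) ^ 1  # index of each row's positive partner
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
world_size, per_replica, dim_in, dim_out = 2, 4, 5, 3
x = torch.randn(world_size * per_replica, dim_in, dtype=torch.float64)
w = torch.randn(dim_in, dim_out, dtype=torch.float64, requires_grad=True)
shards = x.chunk(world_size)  # one shard of the batch per simulated replica

# Reference: gradient of the loss over the full batch in a single process.
(full_grad,) = torch.autograd.grad(toy_contrastive_loss(x @ w), w)

replica_grads = []
for rank in range(world_size):
    # Model the all_gather result as copies without gradient history.
    gathered = [(shard @ w).detach() for shard in shards]
    # Approach #2: put the local embeddings (with grad_fn) back in their slot.
    gathered[rank] = shards[rank] @ w
    loss = toy_contrastive_loss(torch.cat(gathered))
    (grad,) = torch.autograd.grad(loss, w)
    replica_grads.append(grad)

# Model DDP, which averages the parameter gradients across replicas.
ddp_grad = torch.stack(replica_grads).mean(dim=0)

# Under these assumptions the averaged gradient is the full-batch gradient
# divided by world_size, so multiplying by world_size recovers it.
print(torch.allclose(ddp_grad * world_size, full_grad))
```

If this model is right, each replica only backpropagates through its own slice of the gathered embeddings, so the per-replica gradients sum (rather than average) to the full-batch gradient, which would explain the rescaling by `world_size`.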

### Question

Which one should I use? Is it just a matter of style or are there meaningful differences between these approaches?