How to compute distributed (cross-)correlation?

Hello,

I have a batch of tensors of shape (Batch, Patch, Dim); after a 1D batchnorm the shape is (B, D).
These tensors live on multiple GPUs (distributed mode), and I'd like to compute the correlation of each tensor[i] with every other tensor in the batch.
The resulting correlation matrix should have size (B, B); from it I will take the max/min values to decide which tensors correlate the most and the least with each other.
Would the code look like this?

# https://github.com/facebookresearch/moco-v3/blob/main/moco/builder.py#L126
# concat all_gather'ed (B, D) tensors from all GPUs
import torch
import torch.distributed

@torch.no_grad()
def concat_all_gather(gpus_tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # One receive buffer per rank, same shape/dtype as the local tensor.
    tensors_gather = [torch.ones_like(gpus_tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, gpus_tensor, async_op=False)

    # Concatenate along the batch dimension -> (world_size * B_local, D)
    output = torch.cat(tensors_gather, dim=0)
    return output

# gather tensors from all GPUs and compute the correlation matrix
tensors = concat_all_gather(gpus_tensor)
corr = tensors @ tensors.T
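
For reference, here is a minimal sketch (names like correlation_matrix, most_idx and least_idx are just illustrative) of how the (B, B) matrix and the most/least correlated indices could be obtained. Note that tensors @ tensors.T is an unnormalized Gram matrix; if you want Pearson-style correlation, the rows need to be centered and normalized first:

import torch

def correlation_matrix(feats):
    # Center each row and scale it to unit norm, so the dot product
    # between two rows becomes a correlation coefficient in [-1, 1].
    feats = feats - feats.mean(dim=1, keepdim=True)
    feats = feats / (feats.norm(dim=1, keepdim=True) + 1e-8)
    return feats @ feats.T                                   # (B, B)

corr = correlation_matrix(tensors)                           # tensors: (B, D)

# Mask the diagonal so tensor[i] is never matched with itself.
eye = torch.eye(corr.size(0), dtype=torch.bool, device=corr.device)
most_idx  = corr.masked_fill(eye, float("-inf")).argmax(dim=1)   # most correlated j for each i
least_idx = corr.masked_fill(eye, float("inf")).argmin(dim=1)    # least correlated j for each i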

Another problem: after computing the correlation, I have tensor[i] and its least correlated tensor[j]. These two tensors might be on different GPUs, and I need the tensor from before the 1D batchnorm, i.e. tensor[j] with shape (Patch, Dim).
One possible solution is to also gather the tensors before the 1D batchnorm. Is that correct, and is there another way to do it?
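
For what it's worth, a minimal sketch of that idea, assuming pre_bn is the local (B_local, Patch, Dim) tensor before the 1D batchnorm and least_idx comes from the correlation step above (both names are just placeholders):

# Gather the pre-batchnorm tensors once so every rank holds all B of them,
# then index the least-correlated partner locally (no per-pair communication).
all_pre_bn = concat_all_gather(pre_bn.contiguous())      # (B, Patch, Dim)
least_partner = all_pre_bn[least_idx]                    # (B, Patch, Dim): row i is the tensor least correlated with tensor[i]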

Thank you.

Hey @tsly123, the all_gather part looks correct to me, except that you can use torch.empty_like() instead of torch.ones_like(), which should be a little bit faster since it does not need to set all values to 1.
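
That is, only the buffer allocation line changes:

# empty_like allocates the receive buffers without initializing them;
# all_gather overwrites their contents anyway, so the fill is wasted work.
tensors_gather = [torch.empty_like(gpus_tensor)
    for _ in range(torch.distributed.get_world_size())]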

Another problem: after computing the correlation, I have tensor[i] and its least correlated tensor[j]. These two tensors might be on different GPUs, and I need the tensor from before the 1D batchnorm, i.e. tensor[j] with shape (Patch, Dim).

After concat_all_gather, I assume tensors already contains the values from all processes/devices. Is it possible to just reuse that instead of doing another gather?
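
For example, something like the following, assuming the gathered (B, D) features are what you actually need downstream (if you need the pre-batchnorm (Patch, Dim) version, you would still have to gather those separately as you suggested):

# Every rank already holds all B rows after the single concat_all_gather,
# so the least-correlated partner can be indexed locally.
partner = tensors[least_idx]          # (B, D), no extra communication needed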