How to call torch.distributed.nn.all_gather on each node independently?

When I train with multiple nodes, the code below gathers tensors from all GPUs:

import torch
import torch.distributed.nn

def gather_tensors(tensor):
    """
    This function works well on a single node, but not for multi-node training,
    so we want to modify it to gather only across the GPUs on the same node.
    """
    # torch.distributed.nn.all_gather returns one tensor per rank and keeps the
    # autograd graph, so gradients flow back through the gather.
    gathered_tensors = torch.cat(torch.distributed.nn.all_gather(tensor), dim=0)
    return gathered_tensors
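For context, torch.distributed.nn.all_gather above is the autograd-aware variant: it returns one tensor per rank and backpropagates gradients through the gather. The plain torch.distributed.all_gather, by contrast, only fills pre-allocated buffers and tracks no gradients. A minimal sketch of that non-differentiable counterpart, for comparison (assuming the default process group is already initialized):

import torch
import torch.distributed as dist

def gather_tensors_no_grad(tensor):
    # Non-differentiable counterpart: fills buffers in place, so no gradient
    # flows back through the gathered copies.
    out = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(out, tensor.contiguous())
    return torch.cat(out, dim=0)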

But now I want to gather tensors only within the same node, so I implemented a version as below:

    def init_group(self):
        """Create one process group per node. Every rank must run this and
        create the groups in the same order."""
        world_size = dist.get_world_size()
        num_gpus_per_node = torch.cuda.device_count()
        groups = []
        for j in range(world_size // num_gpus_per_node):
            node_ranks = [j * num_gpus_per_node + i for i in range(num_gpus_per_node)]
            node_ranks = [rank for rank in node_ranks if rank < world_size]
            # dist.new_group is collective: all ranks create all groups,
            # but each rank only belongs to one of them.
            group = dist.new_group(ranks=node_ranks)
            groups.append(group)
        return groups
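As an aside, newer PyTorch versions also provide torch.distributed.new_subgroups(), which is meant to create exactly these per-node groups (by default one group per machine, sized by torch.cuda.device_count()). A minimal sketch, assuming that helper exists in your installed version:

import torch.distributed as dist

def init_group_via_subgroups():
    # Returns (the subgroup this rank belongs to, the list of all subgroups).
    # Like dist.new_group, it must be called by every rank.
    cur_subgroup, subgroups = dist.new_subgroups()
    return cur_subgroup, subgroups

Getting the current subgroup back directly would also avoid indexing the list by global_rank // torch.cuda.device_count() later.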

def gather_tensors_on_same_node(tensor, group=dist.group.WORLD):
    """
    Gather tensors from all GPUs on the same node. Assumes each node has 8 GPUs.
    """
    gathered_tensors = torch.cat(torch.distributed.nn.all_gather(tensor, group=group), dim=0)
    return gathered_tensors

and I call them like this:

        global_rank = dist.get_rank()
        group = self.groups[global_rank//torch.cuda.device_count()]
        text_latents = gather_tensors_on_same_node(text_latents, group=group)
        image_latents = gather_tensors_on_same_node(image_latents, group=group)

But it always fails in the backward pass with

    File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/distributed/nn/functional.py", line 294, in backward
        gx = torch.empty_like(grad_outputs[rank])
    IndexError: tuple index out of range

and accesses the data incorrectly. Do you have any suggestions?
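Looking at the traceback, the backward of torch.distributed.nn.all_gather indexes grad_outputs by rank, and grad_outputs only has one entry per member of the group, so it looks like a global rank may be used where the rank inside the subgroup is needed (on the second node the global rank is >= 8, which would overflow an 8-element tuple). For now, the only fallback I can think of is to gather with the plain dist.all_gather and splice the local tensor back in, so that at least this rank's own shard keeps its gradient. A rough sketch (untested in this exact setup, and unlike torch.distributed.nn.all_gather it does not send gradients back for the other ranks' shards):

import torch
import torch.distributed as dist

def gather_tensors_on_same_node_local_grad(tensor, group=None):
    """
    Gather across `group` (e.g. one of the per-node groups created above) with
    the non-differentiable dist.all_gather, then re-insert the local tensor so
    autograd still flows through this rank's own contribution.
    """
    group_size = dist.get_world_size(group=group)
    group_rank = dist.get_rank(group=group)  # rank inside the group, not the global rank
    gathered = [torch.empty_like(tensor) for _ in range(group_size)]
    dist.all_gather(gathered, tensor.contiguous(), group=group)
    gathered[group_rank] = tensor  # keep the autograd graph for the local shard
    return torch.cat(gathered, dim=0)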


Same problem here… do you have any all_gather alternatives that retain the computation graph and can compute gradients?

This also has me stuck.