When I train my work with multinode, the code below can gather all tensors from all_gpus.
def gather_tensors(tensor):
"""
We find this function works well for single node, but not for multi-node
So we want to modify this function to gathered for gpus on same node
"""
gathered_tensors = torch.cat(torch.distributed.nn.all_gather(tensor), dim=0)
return gathered_tensors
But now, I want to gather tensors on same node and implement a version as below:
def init_group(self):
world_size = dist.get_world_size()
num_gpus_per_node = torch.cuda.device_count()
groups = []
for j in range(world_size//num_gpus_per_node):
node_ranks = [j * num_gpus_per_node + i for i in range(num_gpus_per_node)]
node_ranks = [rank for rank in node_ranks if rank < world_size]
group = dist.new_group(ranks=node_ranks)
groups.append(group)
return groups
def gather_tensors_on_same_node(tensor, group=group.WORLD):
"""
Gather tensors from all GPUs on the same node. Assume each node have 8 gpus.
"""
gathered_tensors = torch.cat(torch.distributed.nn.all_gather(tensor, group=group), dim=0)
return gathered_tensors
and I call them with
global_rank = dist.get_rank()
group = self.groups[global_rank//torch.cuda.device_count()]
text_latents = gather_tensors_on_same_node(text_latents, group=group)
image_latents = gather_tensors_on_same_node(image_latents, group=group)
But it will always report error " File “/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/distributed/nn/functional.py”, line 294, in backward
gx = torch.empty_like(grad_outputs[rank])
IndexError: tuple index out of range" and access data wrongly. Do you have any suggestions?