I’m trying to understand how reduce_scatter works in PyTorch. So far, I have found the following code snippet:
import torch
import torch.distributed as dist

class AllGatherFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, reduce_dtype: torch.dtype = torch.float32):
        ctx.reduce_dtype = reduce_dtype
        # Gather a copy of `tensor` from every rank, then concatenate.
        output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(output, tensor)
        output = torch.cat(output, dim=0)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_dtype = grad_output.dtype
        # Split the incoming gradient back into one chunk per rank,
        # then reduce_scatter so each rank receives the sum of the
        # gradients for its own chunk.
        input_list = list(grad_output.to(ctx.reduce_dtype).chunk(dist.get_world_size()))
        grad_input = torch.empty_like(input_list[dist.get_rank()])
        dist.reduce_scatter(grad_input, input_list)
        return grad_input.to(grad_dtype)

def all_gather(tensor):
    return AllGatherFunction.apply(tensor)
which is from an issue in OpenAI’s CLIP repo: how to use multiple GPUs, the default is to use the first CUDA device · Issue #111 · openai/CLIP · GitHub.
I’m trying to figure out what reduce_scatter does by staring at this code snippet, and I have no idea. Could someone help me out?
These docs might be helpful as they visualize communication patterns.
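To make that concrete without spinning up multiple processes, here is a minimal single-process sketch of the arithmetic reduce_scatter performs (the two-rank inputs below are made up for illustration; the real call is collective, this just mimics its result):

import torch

# Pretend there are two ranks; each holds an input_list of
# world_size chunks, as dist.reduce_scatter expects.
world_size = 2
inputs = [
    [torch.tensor([0., 1.]), torch.tensor([2., 3.])],  # rank 0's input_list
    [torch.tensor([4., 5.]), torch.tensor([6., 7.])],  # rank 1's input_list
]

# With the default SUM op, chunk j is summed element-wise across all
# ranks, and the result lands on rank j.
for rank in range(world_size):
    out = sum(inp[rank] for inp in inputs)
    print(f"rank {rank} gets {out}")
# rank 0 gets tensor([4., 6.])   (= [0, 1] + [4, 5])
# rank 1 gets tensor([8., 10.])  (= [2, 3] + [6, 7])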
Thanks! That was a good link. So to clarify, from the docs, the inX refers to the concatenation of the input tensors from each rank (specifically a column-wise concat, so the number of rows stays the same but the length of each row increases). There are also some magic variables, like count and Y, that I don’t quite understand.
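(My guess from the shape of the formula in those docs: Y indexes the receiving rank and count is the number of elements per chunk, so the result would be outY[i] = sum over X of inX[Y*count + i], i.e. rank Y receives the element-wise reduction of everyone’s Y-th chunk.)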
Currently writing up some toy code to figure this out.
Sample code for understanding:
import argparse

import torch as th
import torch.distributed as dist

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("rank", type=int)
    args = parser.parse_args()

    dist.init_process_group(
        backend="nccl",
        init_method="tcp://localhost:8001",
        rank=args.rank,
        # Testing on 2 GPUs
        world_size=2,
    )
    rank = dist.get_rank()
    assert rank == args.rank
    # For NCCL, the tensors must be on different GPUs (one per process)
    device = th.device(f"cuda:{rank}")

    # Each rank starts with a distinct tensor: rank 0 holds [0, 1, 2],
    # rank 1 holds [3, 4, 5].
    tensor_len = 3
    vals = [(rank * tensor_len) + i for i in range(tensor_len)]
    tensor = th.IntTensor(vals).to(device=device)
    # tensor = th.full((3,), rank).to(device=device)
    print(f"Rank: {rank}")
    print(f"Start tensor: {tensor}")

    # all_gather: every rank ends up with the full list of tensors.
    output = [th.empty_like(tensor) for _ in range(dist.get_world_size())]
    # print(f"All-gather list: {output}")
    dist.all_gather(output, tensor)
    print(f"Gathered: {output}")
    # catted = th.cat(output)

    # reduce_scatter, case 1: every rank holds the same aggregate
    # [0..9], chunked into world_size pieces.
    scatter_size = 5
    aggregate = th.IntTensor(list(range(dist.get_world_size() * scatter_size))).to(device=device)
    input_list = list(aggregate.chunk(dist.get_world_size()))
    print(f"Input list: {input_list}")
    rscatter_input = th.empty_like(input_list[dist.get_rank()])
    dist.reduce_scatter(rscatter_input, input_list)
    print(f"Reduce Scatter output: {rscatter_input}")

    # Case 2: the aggregate now differs per rank (offset by
    # rank * scatter_size), which makes the cross-rank summation visible.
    aggregate = th.IntTensor([dist.get_rank() * scatter_size + i for i in range(dist.get_world_size() * scatter_size)]).to(device=device)
    input_list = list(aggregate.chunk(dist.get_world_size()))
    print(f"Input list 2: {input_list}")
    rscatter_input = th.empty_like(input_list[dist.get_rank()])
    dist.reduce_scatter(rscatter_input, input_list)
    print(f"Reduce Scatter output 2: {rscatter_input}")
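For the record, assuming I’ve done the arithmetic right, launching this as python toy.py 0 and python toy.py 1 in two shells (toy.py being whatever you save the script as) should print the following for the reduce_scatter parts: in the first case both ranks chunk the same [0..9] into [0..4] and [5..9], so rank 0 gets [0, 2, 4, 6, 8] and rank 1 gets [10, 12, 14, 16, 18]. In the second case rank 1’s aggregate is offset by 5, so rank 0 gets [0..4] + [5..9] = [5, 7, 9, 11, 13] and rank 1 gets [5..9] + [10..14] = [15, 17, 19, 21, 23]. That matches outY[i] = sum over X of inX[Y*count + i] with count = 5.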