How does reduce_scatter work?

I’m trying to understand how `reduce_scatter` works in PyTorch. So far, I have found the following code snippet:

class AllGatherFunction(torch.autograd.Function):
    """Differentiable all-gather: concatenate every rank's tensor along dim 0.

    Forward: every rank contributes ``tensor``. The per-rank tensors are
    gathered in rank order and concatenated along dim 0, so each rank ends up
    holding the same ``(world_size * tensor.shape[0], ...)`` result.

    Backward: the incoming gradient is split back into ``world_size`` chunks
    along dim 0. ``dist.reduce_scatter`` then sums chunk ``r`` over all ranks
    and delivers that sum to rank ``r``. This is the gradient of rank ``r``'s
    input, because every rank's copy of the gathered output depended on it.

    Assumes a default process group is already initialized and that every
    rank passes a tensor with the same shape and dtype, which ``all_gather``
    requires. TODO: confirm at the call sites.
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, reduce_dtype: torch.dtype = torch.float32):
        """Gather ``tensor`` from all ranks and return the dim-0 concatenation.

        Args:
            ctx: Autograd context; stores ``reduce_dtype`` for ``backward``.
            tensor: This rank's contribution.
            reduce_dtype: Dtype in which ``backward`` performs the cross-rank
                gradient reduction.

        Returns:
            The gathered tensors, concatenated along dim 0 in rank order.
        """
        ctx.reduce_dtype = reduce_dtype

        # c10d collectives reject non-contiguous tensors. empty_like would also
        # copy a non-contiguous layout into the receive buffers.
        src = tensor.contiguous()
        gathered = [torch.empty_like(src) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, src)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Sum each rank's slice of ``grad_output`` across ranks via reduce_scatter.

        Returns:
            A pair: the gradient for ``tensor``, cast back to ``grad_output``'s
            dtype, and ``None`` for the non-tensor ``reduce_dtype`` argument.
        """
        grad_dtype = grad_output.dtype
        # The reduction runs in ctx.reduce_dtype (float32 by default), presumably
        # to avoid precision loss when summing low-precision gradients.
        # .contiguous() guarantees each dim-0 chunk is contiguous for the collective.
        grad_full = grad_output.to(ctx.reduce_dtype).contiguous()
        input_list = list(grad_full.chunk(dist.get_world_size(), dim=0))
        grad_input = torch.empty_like(input_list[dist.get_rank()])
        dist.reduce_scatter(grad_input, input_list)
        return grad_input.to(grad_dtype), None

def all_gather(tensor):
    """Run ``AllGatherFunction`` on ``tensor`` so the gather participates in autograd.

    Only ``tensor`` is passed, so ``AllGatherFunction.forward`` uses its default
    ``reduce_dtype``. See ``AllGatherFunction`` for the exact output layout.
    """
    gathered = AllGatherFunction.apply(tensor)
    return gathered

which comes from GitHub issue #111 in OpenAI’s CLIP repository (“how to use multiple GPUs, the default is to use the first CUDA device”).
I’ve been trying to work out what `reduce_scatter` does just by reading this snippet, but I can’t figure it out. Could someone help me out?

These docs might be helpful, as they visualize the communication patterns.


Thanks, that was a good link! So, to clarify: in the docs, `inX` refers to the concatenation of the input tensors from each rank (specifically a column-wise concatenation, so the number of rows stays the same but each row gets longer).

There are also some unexplained variables, such as `count` and `Y`, that I don’t quite understand.
I’m currently writing some toy code to figure this out.

Sample code for understanding:

import argparse

import torch as th
import torch.distributed as dist

def make_input_list(world_size: int, scatter_size: int, offset: int = 0, device=None) -> list[th.Tensor]:
    """Build the ``input_list`` this rank passes to ``dist.reduce_scatter``.

    Creates the int32 tensor ``offset + [0, 1, ..., world_size * scatter_size - 1]``
    on ``device`` and splits it into ``world_size`` chunks of ``scatter_size``
    elements. After the reduction, chunk ``r`` is the piece that ends up on rank ``r``.
    """
    aggregate = th.arange(world_size * scatter_size, dtype=th.int32, device=device) + offset
    return list(aggregate.chunk(world_size))


def main():
    """Run the all_gather / reduce_scatter demo for one rank.

    Launch one process per GPU, e.g. ``python demo.py 0`` and ``python demo.py 1``.
    """
    parser = argparse.ArgumentParser(description="Toy all_gather / reduce_scatter demo.")
    parser.add_argument("rank", type=int)
    parser.add_argument("--world-size", type=int, default=2)  # Testing on 2 GPUs
    parser.add_argument("--init-method", default="tcp://127.0.0.1:29500")
    parser.add_argument("--backend", default="nccl")
    args = parser.parse_args()

    # Must run before any other dist.* call; without it dist.get_rank() raises.
    dist.init_process_group(
        backend=args.backend,
        init_method=args.init_method,
        world_size=args.world_size,
        rank=args.rank,
    )
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        assert rank == args.rank

        # For NCCL, each rank's tensors must live on its own GPU.
        device = th.device(f"cuda:{rank}")
        th.cuda.set_device(device)

        tensor_len = 3
        # Rank r starts with [r*tensor_len, ..., r*tensor_len + tensor_len - 1].
        tensor = th.arange(tensor_len, dtype=th.int32, device=device) + rank * tensor_len
        print(f"Rank: {rank}")
        print(f"Start tensor: {tensor}")

        output = [th.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(output, tensor)
        print(f"Gathered: {output}")

        scatter_size = 5

        # Experiment 1: every rank holds the same input_list, so rank r should
        # receive world_size * input_list[r] (element-wise sum of identical chunks).
        input_list = make_input_list(world_size, scatter_size, device=device)
        print(f"Input list: {input_list}")
        rscatter_output = th.empty_like(input_list[rank])
        dist.reduce_scatter(rscatter_output, input_list)
        print(f"Reduce Scatter output: {rscatter_output}")

        # Experiment 2: each rank's input_list is shifted by rank * scatter_size,
        # so rank r receives the element-wise sum of every rank's (different) chunk r.
        input_list = make_input_list(world_size, scatter_size, offset=rank * scatter_size, device=device)
        print(f"Input list 2: {input_list}")
        rscatter_output = th.empty_like(input_list[rank])
        dist.reduce_scatter(rscatter_output, input_list)
        print(f"Reduce Scatter output 2: {rscatter_output}")
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()