Functional collectives

Hi, I recently noticed the following code in the all-reduce and broadcast collectives.
Do we need to do these redundant copies?
This is making TP with DTensors slower than TP with normal tensors.

at::Tensor all_reduce(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  auto output = input.clone(at::MemoryFormat::Contiguous);
  return all_reduce_(output, std::move(reduce_op), std::move(group_name));
}

Every all-reduce call incurs this tensor copy, as can be seen in the profiler trace.
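For reference, a rough way to reproduce it in a profiler trace (a minimal sketch, not my exact setup; it assumes the process group is already initialized and that x is a CUDA tensor):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.profiler import profile, ProfilerActivity

x = torch.randn(1024, 1024, device="cuda")
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    # the aten::clone from the C++ snippet above shows up next to the collective
    y = funcol.all_reduce(x, "sum", dist.group.WORLD)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="self_cuda_time_total"))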

I think functional all-reduce is always out-of-place, which is why it does this clone. I think torch.compile would help re-inplace it. cc: @wanchaol
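For example (a minimal sketch, assuming an initialized process group; the eager call goes through the out-of-place path, while the compiled version can be re-inplaced):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def tp_allreduce(x: torch.Tensor) -> torch.Tensor:
    # eager: returns a new tensor (clone + in-place collective under the hood)
    return funcol.all_reduce(x, "sum", dist.group.WORLD)

# under torch.compile, the out-of-place collective can be re-inplaced
compiled_allreduce = torch.compile(tp_allreduce)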


Yes, @agu is correct: functional collectives are out-of-place by design, so they always allocate the output tensor first. If you turn on torch.compile, it will re-inplace the op.

@mayank31398 I wonder how much of a slowdown you observed?

For training we usually use SequenceParallel by default. With SequenceParallel the allocation happens anyway, since the input/output shapes of the all-gather/reduce-scatter are different. For the case without SequenceParallel, IIRC we benchmarked end-to-end training on Llama models and it did not show an observable slowdown compared to TP on normal tensors.
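For concreteness, a rough shape sketch (illustrative only, assuming a TP world size of 4, an initialized process group, and the functional collectives API):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def sp_allgather(shard: torch.Tensor) -> torch.Tensor:
    # shard: [S // 4, B, H]  ->  gathered: [S, B, H], so a new buffer is needed
    return funcol.all_gather_tensor(shard, gather_dim=0, group=dist.group.WORLD)

def sp_reduce_scatter(full: torch.Tensor) -> torch.Tensor:
    # full: [S, B, H]  ->  local shard: [S // 4, B, H]
    return funcol.reduce_scatter_tensor(full, "sum", scatter_dim=0, group=dist.group.WORLD)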

@wanchaol there is a noticeable difference when not using compile.
Is there a plan to move to in-place?
The copy is redundant, and honestly compile doesn't work with everything, so there are real problems with using out-of-place here.

Continuing the discussion in Remove redundant copy in functional collectives · Issue #134388 · pytorch/pytorch · GitHub

@mayank31398 Can we do TP with normal tensors in PyTorch? I thought TP could only be done with DTensors.

Yes, of course. Why do you think we can't do TP without DTensors?
The following code does TP without sequence parallel (SP):

from typing import Any

import torch
import torch.distributed


class ReduceFromTensorParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # sum the partial results across the tensor-parallel group, in place
        torch.distributed.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx: Any, x_grad: torch.Tensor) -> Any:
        # the gradient of an all-reduce is the identity on each rank
        return x_grad
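And a minimal (hypothetical) row-parallel linear that uses it, assuming the default process group is the TP group and the weight is already sharded along the input dimension:

import torch
import torch.nn as nn

class RowParallelLinear(nn.Module):
    # hypothetical Megatron-style example, for illustration only
    def __init__(self, in_features_per_rank: int, out_features: int):
        super().__init__()
        # each rank holds a shard of the weight along the input dimension
        self.linear = nn.Linear(in_features_per_rank, out_features, bias=False)

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        partial = self.linear(x_shard)  # partial result on this rank
        return ReduceFromTensorParallelRegion.apply(partial)  # sum across ranks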