_functional_collectives used in torch.compile

  1. I’m confused about why allreduce is not in traceable_collective_remaps. Since reduce_scatter_tensor is in traceable_collective_remaps, I can call torch.distributed.distributed_c10d.reduce_scatter_tensor in my script and torch.compile works. torch.distributed.all_reduce cannot be used directly in this way. Is there a plan to support torch.distributed.all_reduce in the future?
import torch
import os
from torch.fx.experimental.proxy_tensor import make_fx
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.multiprocessing as mp
from torch._functorch.aot_autograd import aot_module_simplified
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

def toy_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Print the traced FX graph so we can see which collective ops were captured.
    print(gm.graph)
    return gm

def toy_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    return aot_module_simplified(gm, example_inputs, fw_compiler=toy_compiler)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, output):
        # torch.distributed.all_reduce(x)  # this does not work under torch.compile
        from torch.distributed.distributed_c10d import _world
        dist.reduce_scatter_tensor(output, x, group=_world.default_pg)  # this works
        return x

def example(rank, world_size):
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    mod = MyModule()
    opt_mod = torch.compile(mod, dynamic=True, fullgraph=True, backend=toy_backend)
    xx = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
    output1 = torch.empty([2], dtype=torch.int32)  # each rank receives x.numel() // world_size elements
    out = opt_mod(xx, output1)
    print(out)


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size, ),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29516"
    main()

I think the reason reduce_scatter_tensor works is that it is in this map (a hypothetical allreduce entry is sketched right after it):

from torch.distributed.distributed_c10d import (
    all_gather_into_tensor as legacy_allgather,
    reduce_scatter_tensor as legacy_reducescatter,
)

# This dict should contain sets of functions that dynamo is allowed to remap.
# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {
    legacy_allgather: all_gather_tensor_inplace,
    legacy_reducescatter: reduce_scatter_tensor_inplace,
}
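
If allreduce were supported, I would expect the map to simply gain a third entry along these lines (purely hypothetical on my part; legacy_allreduce and all_reduce_inplace are names I made up for illustration, and the latter would have to be an in-place wrapper around the functional collective):

from torch.distributed.distributed_c10d import all_reduce as legacy_allreduce

traceable_collective_remaps = {
    legacy_allgather: all_gather_tensor_inplace,
    legacy_reducescatter: reduce_scatter_tensor_inplace,
    legacy_allreduce: all_reduce_inplace,  # hypothetical entry
}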

  2. In the preceding case, reduce_scatter_tensor has to be given an explicit group argument, which is different from eager mode. Will scripts that work in eager mode be usable with torch.compile as-is in the future, or do we have to make changes like the one below? (A functional-collective sketch follows the snippet.)
dist.reduce_scatter_tensor(output, x, async_op=False)
------> change to
from torch.distributed.distributed_c10d import _world
dist.reduce_scatter_tensor(output, x, group=_world.default_pg)
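
For reference, I believe the purely functional form that torch.compile traces natively looks roughly like this (a sketch, assuming the funcol.reduce_scatter_tensor signature of recent releases, which returns a new tensor instead of writing into a preallocated output):

import torch
import torch.distributed._functional_collectives as funcol

def reduce_scatter_functional(x: torch.Tensor, group) -> torch.Tensor:
    # reduceOp is a string, scatter_dim selects the dimension to shard across
    # ranks, and the result is returned rather than written into an output buffer.
    return funcol.reduce_scatter_tensor(x, "sum", scatter_dim=0, group=group)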

(1)
I think allgather/reducescatter were added to the map precisely because they were needed to unblock the effort of compiling/tracing FSDP code. It should be possible to also support allreduce.

Caveat: we can’t automatically map/trace the versions of the c10d ops where the async_op=True flag is set. The whole premise of this ‘auto rewrite’ is that it converts an op into our functional equivalent. It can do that for the synchronous version of the mutating op by calling the functional op first and then copying the functional op’s output back into the storage of the input tensor. For the async version, dynamo would have to reason about the ‘work object’ returned from the collective, and it does not know how to do that.
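
To make that concrete, here is a rough sketch of the synchronous rewrite (the wrapper name is illustrative, not the actual dynamo internals; it assumes the funcol.all_reduce signature of recent releases):

import torch
import torch.distributed._functional_collectives as funcol

def all_reduce_sync_rewrite(tensor: torch.Tensor, group) -> torch.Tensor:
    # 1. Call the functional (non-mutating) collective, which returns a new tensor.
    reduced = funcol.all_reduce(tensor, "sum", group)
    # 2. Copy the result back into the input tensor's storage so the in-place
    #    semantics of dist.all_reduce are preserved for the rest of the program.
    tensor.copy_(reduced)
    return tensor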

(2) IIUC your script omits the group argument in eager mode, which means it uses the default group; that is probably something we could ‘infer’ inside dynamo.
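
Concretely, inferring the default group would let the plain eager call stand in for the explicit spelling that is required today (a sketch; _get_default_group is the private helper behind the default process group):

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

def reduce_scatter_on_default_group(output: torch.Tensor, x: torch.Tensor) -> None:
    # Eager mode accepts this and falls back to the default group:
    #     dist.reduce_scatter_tensor(output, x)
    # Under torch.compile you currently have to spell the group out yourself:
    dist.reduce_scatter_tensor(output, x, group=_get_default_group())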

I’ll create a task for these changes.