Capture training graph with collectives via TorchTitan

Hello
I am currently training a Llama 1B model with TorchTitan and would like to capture the training graph(s), including the collectives that get inserted, when I use different parallelism combinations. Is there a way to do that? I tried intercepting a training step and capturing it with aot_module from functorch.compile, but FakeTensor propagation does not seem to work with it. More generally, is there a better way to capture the graph along with all the collectives that the various degrees of parallelism insert?
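
For reference, this is roughly the kind of capture I attempted (a sketch; `parallelized_model` and `batch` are placeholders for my Llama model after the parallelisms are applied and a sample input):

```python
import torch
from functorch.compile import aot_module, make_boxed_func

def print_graph(gm: torch.fx.GraphModule, example_inputs):
    # Dump the traced graph; if tracing works, collectives should show up
    # as functional collective ops in here.
    gm.graph.print_tabular()
    return make_boxed_func(gm.forward)

# parallelized_model / batch stand in for my actual model and data.
captured = aot_module(parallelized_model, fw_compiler=print_graph, bw_compiler=print_graph)
out = captured(batch)          # first call traces and prints the forward graph
out.sum().backward()           # backward graph gets captured as well
```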

Does [torchtitan/experiments/simple_fsdp](https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/simple_fsdp) provide what you need?

You get the full graph with all the collectives in it, although there are no communication optimizations for FSDP in that graph, which means FSDP comms like all-gathers and reduce-scatters are exposed.
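
Since simple_fsdp is designed to be traced end-to-end by torch.compile, one way to inspect the captured graphs is a custom backend that just prints them. A minimal sketch, assuming `model` is already wrapped via the simple_fsdp path, distributed is initialized, and `batch` is a sample input:

```python
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

def dump_graph(gm: torch.fx.GraphModule, example_inputs):
    # Post-AOTAutograd forward/backward graphs; the FSDP all-gathers and
    # reduce-scatters appear as functional collective ops in these graphs.
    gm.graph.print_tabular()
    return make_boxed_func(gm.forward)

# `model` / `batch` are placeholders for the simple_fsdp-wrapped model and input.
compiled_model = torch.compile(
    model, backend=aot_autograd(fw_compiler=dump_graph, bw_compiler=dump_graph)
)
out = compiled_model(batch)    # first call triggers compilation and graph dumps
out.sum().backward()
```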