Hello
I am currently training a Llama 1B model with Torchtitan and would like to capture the training graph(s), including the collectives that appear when I use different parallelism combinations. Is there a way to capture those? I tried intercepting a training step and using `aot_module` from `functorch.compile` to capture it, but I think fake tensor propagation is not working with it. More generally, is there a better way to capture the graph and all the collectives that get inserted into it by the various degrees of parallelism?
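For reference, this is roughly what I tried (a minimal sketch, with a toy module standing in for the actual Torchtitan model and `print_graph` as an illustrative compiler callback that just dumps the captured FX graph):

```python
import torch
import torch.nn as nn
from functorch.compile import aot_module

def print_graph(gm: torch.fx.GraphModule, example_inputs):
    # Dump the captured FX graph; collectives would show up here as nodes
    # if tracing through the parallelized model worked.
    gm.graph.print_tabular()
    return gm  # hand the graph back unchanged so execution still works

# Toy module standing in for the real Torchtitan llama1b model.
model = nn.Linear(16, 16)
compiled = aot_module(model, fw_compiler=print_graph, bw_compiler=print_graph)

x = torch.randn(4, 16)
out = compiled(x)        # prints the forward graph
out.sum().backward()     # prints the backward graph
```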
Does torchtitan/torchtitan/experiments/simple_fsdp at main · pytorch/torchtitan · GitHub provide what you need?
You get the full graph with all collectives in it, although there are no communication optimizations for FSDP in the graph you get, which means FSDP comms such as all-gathers and reduce-scatters are exposed.
Thanks @tianyu
This looks helpful. Will this work with other parallelism strategies, such as TP, CP, and PP?
Also, when I tried to run this simply with `TORCH_TRACE="outputs/compile_trace" CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_1B.toml" NGPU=4 ./run_train.sh --model.name llama3_simple_fsdp --training.compile`
I get the following error:

```text
torch._dynamo.exc.Unsupported: TypeError when making fake tensor call
Explanation:
Developer debug context: TypeError <function TensorVariable.method_redistribute.<locals>.redistribute_fn_with_prim_types at 0x7f3cb7d1f9c0>: DTensor.redistribute() got an unexpected keyword argument 'forward_dtype'

from user code:
  File "/home/amodab01/kyojin/torchtitan/models/llama3/model.py", line 507, in forward
    h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 192, in forward
    self.weight,
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/nn/utils/parametrize.py", line 406, in get_parametrized
    return parametrization()
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/nn/utils/parametrize.py", line 299, in forward
    x = self[0](self.original)
  File "/home/amodab01/kyojin/torchtitan/experiments/simple_fsdp/simple_fsdp.py", line 213, in forward
    output = self.replicate_compute(x)
  File "/home/amodab01/kyojin/torchtitan/experiments/simple_fsdp/simple_fsdp.py", line 189, in replicate_compute
    output = x.redistribute(
```
I can provide the log files for more detail
Yes, it should work with all the other parallelisms, except for HSDP + TP, which we are still working on.
For the error you saw, it's because SimpleFSDP + mixed precision training relies on a feature we recently added, so you may need to use a PyTorch nightly build instead of a stable release.
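For example, you can check which build you are on like this (the version string and install command are just illustrative; adjust the CUDA suffix for your environment):

```python
import torch

# Nightly builds report a ".dev" version string, e.g. "2.x.0.devYYYYMMDD+cuXXX",
# while stable releases look like "2.x.0+cuXXX".
print(torch.__version__)

# To switch to a nightly build (pick the index URL matching your CUDA version):
#   pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
```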
Thanks @tianyu
A quick naive follow-up: does the graph get dumped automatically? I tried turning on a few Torch Inductor trace configs but wasn't able to find it. These are the configs I tried:
```python
import torch._inductor.config

# Inductor trace/debug flags for dumping graph artifacts
torch._inductor.config.trace.debug_log = True
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.graph_diagram = True
torch._inductor.config.trace.draw_orig_fx_graph = True
```
Also, is there any way I can intercept the NCCL calls in the graph?
You could set `TORCH_COMPILE_DEBUG=1` when running the code, which will dump the generated Triton code into a local `torch_compile_debug` folder. Inside `torch_compile_debug/*/torchinductor/*/fx_graph_runnable.py`, you will find the communication operators (e.g., `all_gather_into_tensor`, `reduce_scatter_tensor`) along with the other compute operations.
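If it helps, here is a small sketch that scans those dumped files for collective calls (the op-name list is an illustrative guess, not exhaustive):

```python
import glob

# Collective op names to look for in the dumped FX graphs (illustrative, not exhaustive).
COLLECTIVES = (
    "all_gather_into_tensor",
    "reduce_scatter_tensor",
    "all_reduce",
    "all_to_all_single",
)

for path in glob.glob("torch_compile_debug/*/torchinductor/*/fx_graph_runnable.py"):
    with open(path) as f:
        for lineno, line in enumerate(f, start=1):
            if any(op in line for op in COLLECTIVES):
                print(f"{path}:{lineno}: {line.strip()}")
```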