I have several questions about how to use torch.compile properly, and I found the tutorial and documentation lacking. I will try to list all of them here, including those I found answered in this forum but missing from the tutorial, for the benefit of future readers.
- How do I use torch.compile with a non-trivial nn.Module? This post on Stack Overflow sums up my question perfectly. Namely, should torch.compile be called manually on every sub-module I wish to compile, or does it handle that automatically? More generally, how does torch.compile behave when applied to an nn.Module? Does it compile only the forward method? In my case, like the OP of that Stack Overflow post, I observe no performance difference at all when using torch.compile, even though my model has 27M parameters and I am using one of the GPUs recommended in the tutorial (V100). I also notice that the torch.compile call returns immediately, as if it were doing nothing.
- Should I wrap the model in DDP before or after calling torch.compile? In the DistributedDataParallel tutorial (I can't post another link due to the new-user limit on this forum), torch.compile is called on the model after it has been wrapped in DDP. But this post recommends calling torch.compile before wrapping. So which is right?
- Should I move the model to the device before or after calling torch.compile? That same post recommends moving the model to CUDA before calling torch.compile, but the tutorial does not mention this, even though it seems like a pretty essential consideration.