I want to torch.compile the same model several times, each variant optimized for a different input shape, with all of them pointing to the same large model weights. Letting a single compiled model be dynamic over the shape range resulted in too large a slowdown.
The ID differs between the compiled model variables, but it worries me that models 3, 4 and onwards suddenly compile quickly on their first pass. Does something optimize graph generation after a few compilations, or are they all pointing to the same graph, and model #2 is slow because the graph becomes dynamic once it sees a different input size?
Setting dynamic=False in torch.compile() makes every JIT compilation take the full, slow time.
(1) When you call m = torch.compile(m), you are not actually generating a “copy” of the model, so you don’t have to worry about your weights not being shared across compiled regions. You are effectively just marking the torch.nn.Module.__call__ method on your model for compilation.
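A quick way to check the weight-sharing claim is to compare the parameters seen through the compiled wrappers with those of the original module. This is only a sketch: the small nn.Linear is a hypothetical stand-in for the large model, and nothing is actually compiled here because compilation only happens on the first call.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the large model from the question.
model = nn.Linear(16, 4)

# Wrap the same module twice. No compilation happens yet; it is deferred
# until the first time a wrapper is called.
compiled_a = torch.compile(model)
compiled_b = torch.compile(model)

# Each wrapper exposes the very same Parameter objects as the original module.
shares_a = all(p is q for p, q in zip(model.parameters(), compiled_a.parameters()))
shares_b = all(p is q for p, q in zip(model.parameters(), compiled_b.parameters()))

# The underlying storage is identical as well, so no weights were copied.
same_storage = model.weight.data_ptr() == next(compiled_a.parameters()).data_ptr()

print(shares_a, shares_b, same_storage)
```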
(2) You don’t actually need to call torch.compile multiple times. If you just do this:
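A minimal sketch of compiling once and calling the result with several input shapes might look like the following. TinyModel and the batch sizes are hypothetical and not taken from the original post, and backend="eager" is an assumption made only to keep the example light (it skips code generation); drop that argument to use the default backend.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 8)

    def forward(self, x):
        return torch.relu(self.proj(x))


model = TinyModel()

# Compile once. backend="eager" is an assumption for a quick demo only.
compiled = torch.compile(model, backend="eager")

with torch.no_grad():
    # The same compiled object is reused for every input shape.
    outputs = [compiled(torch.randn(batch, 32)) for batch in (4, 16, 64)]

shapes = [tuple(out.shape) for out in outputs]
print(shapes)
```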
(1) The first time you invoke your compiled model, we perform compilation to get a compiled artifact, and we also generate a set of guards that tell us when this compiled artifact is safe to reuse (for example, we guard on things like the shape of your input, its device and dtype, whether it requires_grad, etc.).
(2) When you run your compiled model on another input, we run the guards that we generated, and if they all pass we reuse one of the previously generated compiled artifacts. If any guard fails, we have to recompile (causing the slowness you mentioned above).
If you want to see the source of recompilations, you can run with TORCH_LOGS="recompiles" to see why each one happened.