Indicating parameters as static to `torch.compile`

In JAX, when I JIT-compile my functions, I indicate which parameters I want to be static with the `static_argnums` argument to `jax.jit`. When I do this, JAX knows to compile one version of the function for each distinct value of those parameters, while the remaining parameters stay traced variables in the compiled function.
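
For concreteness, this is the kind of thing I mean on the JAX side (a toy sketch, not my actual code):

```python
import jax
import jax.numpy as jnp
from functools import partial

# `n` is marked static: JAX compiles one specialization per distinct value of `n`,
# while `x` stays a traced argument, so new arrays reuse the compiled code.
@partial(jax.jit, static_argnums=1)
def scale(x, n):
    return x * n

scale(jnp.ones(4), 2)   # compiles for n=2
scale(jnp.zeros(4), 2)  # reuses the n=2 compilation
scale(jnp.ones(4), 3)   # recompiles for n=3
```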

In my case I’m working on fluid simulations, where the values of my fields will change, but, for instance, my boundary conditions never will. So I have one class containing the field data and another containing the domain data, where the domain is completely static between runs.
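
Roughly, the structure looks like this (attribute names are just for illustration):

```python
import torch
from dataclasses import dataclass

@dataclass
class Field:
    # changes every step and every run
    velocity: torch.Tensor
    pressure: torch.Tensor

@dataclass
class Domain:
    # completely static between runs
    shape: tuple[int, int]
    boundary: str
```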

@torch.compile
def simulate(fluid: Field, domain: Domain):
    # fluid data changes every call; domain (grid, boundary conditions) never does
    ...

Here I want to recompile when any of the data in `domain` changes, but not when the data in `fluid` changes. In JAX I have used Equinox to define data structures with static parameters anywhere in the pytree, via the `static=True` flag in `eqx.field`. There this is easy, and I can control exactly when I need to guard and recompile versus when I can reuse a function I’ve previously compiled.
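
For reference, the Equinox version of the same Field/Domain split looks roughly like this (again just a sketch):

```python
import equinox as eqx
import jax.numpy as jnp

class Domain(eqx.Module):
    # static leaves: a new value means a new treedef, so JAX retraces/recompiles
    shape: tuple = eqx.field(static=True)
    boundary: str = eqx.field(static=True)

class Field(eqx.Module):
    # dynamic leaves: new array values reuse the already-compiled function
    velocity: jnp.ndarray
    pressure: jnp.ndarray

@eqx.filter_jit
def simulate(fluid: Field, domain: Domain) -> Field:
    # placeholder update; the real stepping logic goes here
    return Field(velocity=fluid.velocity * 0.99, pressure=fluid.pressure)
```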

Is there a way to accomplish this in PyTorch 2 with `torch.compile`? If there are any resources, or if anyone knows how these compile guards are decided, I would appreciate hearing about them. Thank you!
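
For context, this is roughly how I’ve been trying to see which guards get installed and what triggers a recompilation (I’m not sure these logging flags are the right tool, which is partly why I’m asking):

```python
import torch

# I believe these flags exist in recent PyTorch 2.x releases (equivalently,
# TORCH_LOGS="guards,recompiles" as an environment variable); they print the
# guards Dynamo installs and the reason for each recompile.
torch._logging.set_logs(guards=True, recompiles=True)

# ...then call the compiled simulate(...) and watch which guards fire.
```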
