Torch.compile seems to hang

As a follow-up to Is there an equivalent of jax.lax.scan (eg in torch.func)?, I have been trying to compile my Kalman filter code using torch.compile.
However, it seems to hang (even though the simple torch.compile example in the documentation works). I am running torch 2.0 on a MacBook with an Intel CPU.

Here is the key code snippet:

def kf_pt(params, emissions, return_covs=False, compile=False):
    F, Q, R = params['F'], params['Q'], params['R']
    def step(carry, t):
        ...

    if compile:
        step = torch.compile(step)
    num_timesteps = len(emissions)
    D = len(params['mu0'])
    # preallocate outputs
    filtered_means = torch.zeros((num_timesteps, D))
    filtered_covs = torch.zeros((num_timesteps, D, D))
    ll = 0
    carry = (ll, params['mu0'], params['Sigma0'])
    for t in range(num_timesteps):
        if return_covs:
            carry, (filtered_means[t], filtered_covs[t]) = step(carry, t)
        else:
            carry, filtered_means[t] = step(carry, t)
    ll = carry[0]  # running log-likelihood accumulated in the carry
    return ll, filtered_means, filtered_covs

Full code is at dynamax/kf-linreg-pt.ipynb at main · probml/dynamax · GitHub

When I call ll_pt, kf_means_pt, kf_covs_pt = kf_pt(param_dict_pt, Y_pt, return_covs, compile=True) it just hangs (for several minutes). If I interrupt it I see

--------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<timed exec> in <module>

<ipython-input-42-ada65422be98> in kf_pt(params, emissions, return_covs, compile)
     39     for t in range(num_timesteps):
     40         if return_covs:
---> 41             carry, (filtered_means[t], filtered_covs[t]) = step(carry, t)
     42         else:
     43             carry, filtered_means[t] = step(carry, t)

/opt/anaconda3/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    207             dynamic_ctx.__enter__()
    208             try:
--> 209                 return fn(*args, **kwargs)
    210             finally:
    211                 set_eval_frame(prior)

/opt/anaconda3/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py in catch_errors(frame, cache_size)
    335 
    336         with compile_lock:
--> 337             return callback(frame, cache_size, hooks)
    338 
    339     catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]

...

/opt/anaconda3/lib/python3.8/site-packages/torch/_inductor/graph.py in compile_to_module(self)
    573             print(code)
    574 
--> 575         mod = PyCodeCache.load(code)
    576         for name, value in self.constants.items():
    577             setattr(mod, name, value)

/opt/anaconda3/lib/python3.8/site-packages/torch/_inductor/codecache.py in load(cls, source_code)
    526                 mod.__file__ = path
    527                 mod.key = key
--> 528                 exec(code, mod.__dict__, mod.__dict__)
    529                 # another thread might set this first
    530                 cls.cache.setdefault(key, mod)

/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/torchinductor_kpmurphy/ym/cymnh75e3azdeay6wzkazx6yl3p457udrtmnkd3pyeo5vzgsls7k.py in <module>
    109 
    110 
--> 111 async_compile.wait(globals())
    112 del async_compile
    113 

/opt/anaconda3/lib/python3.8/site-packages/torch/_inductor/codecache.py in wait(self, scope)
    713                     pbar.set_postfix_str(key)
    714                 if isinstance(result, (Future, TritonFuture)):
--> 715                     scope[key] = result.result()
    716                     pbar.update(1)
    717 

/opt/anaconda3/lib/python3.8/concurrent/futures/_base.py in result(self, timeout)
    432                 return self.__get_result()
    433 
--> 434             self._condition.wait(timeout)
    435 
    436             if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

/opt/anaconda3/lib/python3.8/threading.py in wait(self, timeout)
    300         try:    # restore state no matter what (e.g., KeyboardInterrupt)
    301             if timeout is None:
--> 302                 waiter.acquire()
    303                 gotit = True
    304             else:

KeyboardInterrupt: 

This was a bug which was fixed in the nightlies; try torch.compile(.., backend="aot_eager") instead - inductor, the default backend, is not supported on Mac.

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
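For example, applied to your snippet it would look something like this (a sketch, keeping your variable names):

if compile:
    # use the aot_eager backend so the default inductor backend (unsupported on Mac here) is skipped
    step = torch.compile(step, backend="aot_eager")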

Thanks. That fixes the problem. However, I noticed my code is 5x slower after compiling.
I get the warning:

[2023-04-10 17:59:00,831] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
   function: 'step' (kf_linreg_jax_vs_pt.py:171)
   reasons:  L['t'] == 63


Any ideas?

Apologies for the delay, but there are a few issues:

  1. AOT eager won’t have substantial speedups relative to inductor
  2. Your model is dynamic, so the recompilations are what is giving the appearance of a hang
  3. For some reason dynamic=True has no impact
  4. After it recompiles, your model OOMs
  5. Your benchmark suite doesn’t factor in warmup
  6. torch.compile speedups will be most significant on newer GPUs like the A100

For now I’m running experiments on an A10G, which means I default to inductor, the compiler with the most significant speedups.

For the recompiles, I ran TORCH_LOGS="recompiles" python kf_linreg_jax_vs_pt.py and it showed that the culprit was the t variable here: https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py#L86. If I fix it to 1, the hang goes away.
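To illustrate, here is a stripped-down sketch of that kind of pattern (illustrative names, not your actual script): in torch 2.0, Dynamo guards on the concrete value of the Python int used for indexing, which is why the warning above shows a guard like L['t'] == 63.

import torch

X = torch.randn(100, 5)

@torch.compile
def step(t):
    return X[t] @ X[t]   # indexes a tensor by the Python int t

for t in range(100):
    step(t)              # each new value of t fails the guard (L['t'] == ...),
                         # so it recompiles every iteration until cache_size_limit is hit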

To handle dynamic models we have an argument, torch.compile(m, dynamic=True), but for some reason I don’t see its impact on your model.

Even with t fixed your model is OOM’ing. I was able to make the OOM go away with the config line below, and I’ve opened an issue about this on GitHub since the pattern matcher is on by default: OOM in fuse_attention inductor pass · Issue #99084 · pytorch/pytorch · GitHub

import torch._inductor.config
torch._inductor.config.pattern_matcher = False

As far as benchmarking is concerned, when you call torch.compile(m) the compilation only happens on the first inference, so for benchmarks you need to exclude the first inference time; the assumption is you’ll amortize it over a long enough experiment.
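Roughly something like this (a sketch, assuming kf_pt and the inputs from your notebook are in scope):

import time

# warmup call: triggers graph capture and compilation, excluded from timing
kf_pt(param_dict_pt, Y_pt, return_covs, compile=True)

# steady-state call: this is the cost you amortize over a long experiment
start = time.perf_counter()
kf_pt(param_dict_pt, Y_pt, return_covs, compile=True)
print(f"compiled steady-state time: {time.perf_counter() - start:.3f}s")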

Finally, I added a torch.set_default_device('cuda') to your script, which defaults all tensors to the GPU (an A100). The benchmarks are not anything to write home about quite yet, but the speedups are there:

sm/demos$ python kf_linreg_jax_vs_pt.py 
torch, time=0.604 compile False N 100 D 500
[2023-04-13 21:36:10,010] torch._inductor.utils: [WARNING] make_fallback(aten.cumprod): a decomposition exists, we should switch to it
torch, time=4.884 compile True N 100 D 500
torch, time=0.552 compile True N 100 D 500

Thanks. I know that jax.jit does tracing, and as long as the shapes don’t change when the args change, it doesn’t recompile the function. Here t just specifies which row to take out of a global variable X, so the shape of the variables is constant across iterations. I assumed torch.compile would do something similar to jax.
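For example, something like this toy case is what I had in mind (an illustration, not the notebook code):

import jax
import jax.numpy as jnp

X = jnp.ones((100, 5))

@jax.jit
def take_row(X, t):
    return X[t]      # t is traced as a scalar, so its value is not baked into the compiled code

take_row(X, 0)       # traces and compiles once
take_row(X, 63)      # same shapes/dtypes, so the cached compilation is reused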


It should have been the case here as well; a PR with the fix is out: Don't specialize when indexing by SymInt by ezyang · Pull Request #99123 · pytorch/pytorch · GitHub

And the OOM issue has a PR out too: [inductor] Use FakeTensorMode() when creating patterns by jansel · Pull Request #99128 · pytorch/pytorch · GitHub