As a follow-up to "Is there an equivalent of jax.lax.scan (e.g. in torch.func)?", I have been trying to compile my Kalman filter code using torch.compile.
However, it seems to hang (even though the simple torch.compile example in the documentation works). I am running PyTorch 2.0 on a MacBook with an Intel CPU.
Here is the key code snippet:
def kf_pt(params, emissions, return_covs=False, compile=False):
    """Run a Kalman filter over ``emissions`` one time step at a time.

    Args:
        params: dict of parameters; this function reads the keys 'F', 'Q',
            'R', 'mu0' and 'Sigma0'.
        emissions: observations indexed by time step along the first
            dimension (``len(emissions)`` is the number of time steps).
        return_covs: if True, also collect the filtered covariance of every
            time step.
        compile: if True, wrap the per-step function with ``torch.compile``.

    Returns:
        ``(ll, filtered_means, filtered_covs)`` where ``ll`` is the first
        element of the final carry (the log-likelihood accumulator),
        ``filtered_means`` has shape ``(num_timesteps, D)`` and
        ``filtered_covs`` has shape ``(num_timesteps, D, D)``, or is ``None``
        when ``return_covs`` is False.
    """
    F, Q, R = params['F'], params['Q'], params['R']

    def step(carry, t):
        # Body elided in the original post. From the call sites below it must
        # return ``(new_carry, mean)`` or ``(new_carry, (mean, cov))``, with
        # ``new_carry`` shaped like the initial ``(ll, mu, Sigma)`` carry --
        # TODO confirm against the full notebook.
        ...

    if compile:
        # NOTE(review): ``t`` is a plain Python int, so torch.compile may
        # specialise on its value and recompile on every time step. The
        # reported hang is inside ``async_compile.wait`` (the inductor
        # compile worker pool); limiting inductor to a single compile thread
        # (``torch._inductor.config.compile_threads = 1``) is a commonly
        # suggested workaround -- confirm for this setup.
        step = torch.compile(step)

    num_timesteps = len(emissions)
    mu0 = params['mu0']
    D = len(mu0)

    # The outputs were written into without ever being allocated. Match
    # mu0's dtype/device; this assumes mu0 is 1-D of length D -- TODO confirm.
    template = torch.as_tensor(mu0)
    filtered_means = template.new_zeros((num_timesteps, D))
    filtered_covs = template.new_zeros((num_timesteps, D, D)) if return_covs else None

    carry = (0, mu0, params['Sigma0'])
    for t in range(num_timesteps):
        if return_covs:
            carry, (filtered_means[t], filtered_covs[t]) = step(carry, t)
        else:
            carry, filtered_means[t] = step(carry, t)

    # The log-likelihood accumulated by ``step`` lives in the carry; returning
    # the local initial value would always give 0.
    ll = carry[0]
    return ll, filtered_means, filtered_covs
The full code is in dynamax/kf-linreg-pt.ipynb (main branch of the probml/dynamax repository on GitHub).
When I call `ll_pt, kf_means_pt, kf_covs_pt = kf_pt(param_dict_pt, Y_pt, return_covs, compile=True)`,
it just hangs (for several minutes). If I interrupt it, I see:
--------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
<timed exec> in <module>
<ipython-input-42-ada65422be98> in kf_pt(params, emissions, return_covs, compile)
39 for t in range(num_timesteps):
40 if return_covs:
---> 41 carry, (filtered_means[t], filtered_covs[t]) = step(carry, t)
42 else:
43 carry, filtered_means[t] = step(carry, t)
/opt/anaconda3/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
207 dynamic_ctx.__enter__()
208 try:
--> 209 return fn(*args, **kwargs)
210 finally:
211 set_eval_frame(prior)
/opt/anaconda3/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py in catch_errors(frame, cache_size)
335
336 with compile_lock:
--> 337 return callback(frame, cache_size, hooks)
338
339 catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined]
...
/opt/anaconda3/lib/python3.8/site-packages/torch/_inductor/graph.py in compile_to_module(self)
573 print(code)
574
--> 575 mod = PyCodeCache.load(code)
576 for name, value in self.constants.items():
577 setattr(mod, name, value)
/opt/anaconda3/lib/python3.8/site-packages/torch/_inductor/codecache.py in load(cls, source_code)
526 mod.__file__ = path
527 mod.key = key
--> 528 exec(code, mod.__dict__, mod.__dict__)
529 # another thread might set this first
530 cls.cache.setdefault(key, mod)
/var/folders/mn/vt7cgfsx6zs9vblhvbbk7pf8003xtr/T/torchinductor_kpmurphy/ym/cymnh75e3azdeay6wzkazx6yl3p457udrtmnkd3pyeo5vzgsls7k.py in <module>
109
110
--> 111 async_compile.wait(globals())
112 del async_compile
113
/opt/anaconda3/lib/python3.8/site-packages/torch/_inductor/codecache.py in wait(self, scope)
713 pbar.set_postfix_str(key)
714 if isinstance(result, (Future, TritonFuture)):
--> 715 scope[key] = result.result()
716 pbar.update(1)
717
/opt/anaconda3/lib/python3.8/concurrent/futures/_base.py in result(self, timeout)
432 return self.__get_result()
433
--> 434 self._condition.wait(timeout)
435
436 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:
/opt/anaconda3/lib/python3.8/threading.py in wait(self, timeout)
300 try: # restore state no matter what (e.g., KeyboardInterrupt)
301 if timeout is None:
--> 302 waiter.acquire()
303 gotit = True
304 else:
KeyboardInterrupt: