Hello!
I am trying to use torch.compile
in some non-standard contexts. Specifically, I am trying to compile small probabilistic models that do not make use of any nn functionality at all, but have tons of overhead because there are many small ops that must happen.
I am having a real rough time getting torch.compile
to work in any context. Whenever there is an issue, I’m faced with this massive message that doesn’t even indicate the line that raised the issue. I have to go in and comment out code until it runs to find the line:
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/__init__.py", line 1390, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
return compile_fx(
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
return aot_autograd(
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
return inner_compile(
File "/users/jmschr/anaconda3/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
compiled_fn = graph.compile_to_fn()
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
return self.compile_to_module().call
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
mod = PyCodeCache.load(code)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 528, in load
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_jmschr/44/c44qbatu2ybvfqippj5t6gqx6v5ikyqnjrvfhtiazn6s26vfcqfy.py", line 106, in <module>
async_compile.wait(globals())
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 715, in wait
scope[key] = result.result()
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 573, in result
self.future.result()
File "/users/jmschr/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 446, in result
return self.__get_result()
File "/users/jmschr/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/users/jmschr/github/torchegranate/test_silent_hmm.py", line 70, in <module>
print(f(X))
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
return callback(frame, cache_size, hooks)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
self.output.compile_subgraph(
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
More specifically, I’ve found that when I have a function that calls another function, torch.compile
seems to fail, even if that second function can be compiled. Here’s a minimally reproducing script where I define a simple probability distribution and a mixture and show that the probability distributions log_prob
method can be compiled but that when I try to call it through the mixture model, it fails.
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import torch
torch.set_float32_matmul_precision('high')
import torch._dynamo
torch._dynamo.reset()
torch._dynamo.config.verbose = True
class Expon(torch.nn.Module):
def __init__(self, scales):
super().__init__()
self.scales = scales
self._log_scales = torch.log(scales)
def log_prob(self, X):
return torch.sum(-self._log_scales - (1. / self.scales) * X, dim=1)
class Mixture(torch.nn.Module):
def __init__(self, dists):
self.dists = dists
self.k = len(dists)
def emissions(self, X):
e = torch.empty(X.shape[0], self.k, device='cuda:0')
for i, d in enumerate(self.dists):
e[:, i] = d.log_prob(X)
return e
n, d = 8, 6
X = torch.randn(n, d).cuda()
mu1 = torch.exp(torch.randn(d)).cuda()
d1 = Expon(mu1)
mu2 = torch.exp(torch.randn(d)).cuda()
d2 = Expon(mu2)
print(d1.log_prob(X))
f = torch.compile(d1.log_prob, mode='reduce-overhead', fullgraph=True)
print(f(X))
print("\n\n")
model = Mixture([d1, d2])
print(model.emissions(X))
f = torch.compile(model.emissions, mode='reduce-overhead', fullgraph=True)
print(f(X))
The output from this is:
tensor([ 0.7098, 11.4090, 16.9829, 2.7388, 2.8006, -8.2336, 21.9166, -2.0553],
device='cuda:0')
tensor([ 0.7098, 11.4090, 16.9830, 2.7388, 2.8006, -8.2336, 21.9166, -2.0553],
device='cuda:0')
tensor([[ 0.7098, 3.1660],
[11.4090, 2.1182],
[16.9829, 7.2627],
[ 2.7388, 0.8454],
[ 2.8006, 3.6874],
[-8.2336, -2.3724],
[21.9166, 13.9888],
[-2.0553, 3.9145]], device='cuda:0')
Traceback (most recent call last):
File "/users/jmschr/github/torchegranate/test_silent_hmm.py", line 78, in <module>
print(f(X))
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
return callback(frame, cache_size, hooks)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 364, in _compile
check_fn = CheckFunctionManager(
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/guards.py", line 547, in __init__
guard.create(local_builder, global_builder)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_guards.py", line 163, in create
return self.create_fn(self.source.select(local_builder, global_builder), self)
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/guards.py", line 299, in NN_MODULE
unimplemented(f"Guard setup for uninitialized class {type(val)}")
File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Guard setup for uninitialized class <class '__main__.Mixture'>
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
tl;dr, when I try to compile the log_prob
method in the probability distribution it succeeds, but when I try to compile the emissions
method that calls the log_prob
method it fails. Any tips on how to get this to succeed? Is this a lost cause for now? Thanks!!