Torch.compile raises error when compiled function calls other functions

Hello!

I am trying to use torch.compile in some non-standard contexts. Specifically, I am trying to compile small probabilistic models that don’t use any nn functionality at all but have a lot of overhead because they consist of many small ops.

I am having a really rough time getting torch.compile to work in any context. Whenever there is an issue, I’m faced with this massive message that doesn’t even indicate which line raised the issue, so I have to comment out code until the script runs in order to find the offending line:

  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 401, in compile_fx
    return compile_fx(
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
    return inner_compile(
  File "/users/jmschr/anaconda3/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 177, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/graph.py", line 586, in compile_to_fn
    return self.compile_to_module().call
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/graph.py", line 575, in compile_to_module
    mod = PyCodeCache.load(code)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 528, in load
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_jmschr/44/c44qbatu2ybvfqippj5t6gqx6v5ikyqnjrvfhtiazn6s26vfcqfy.py", line 106, in <module>
    async_compile.wait(globals())
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 715, in wait
    scope[key] = result.result()
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 573, in result
    self.future.result()
  File "/users/jmschr/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/users/jmschr/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/users/jmschr/github/torchegranate/test_silent_hmm.py", line 70, in <module>
    print(f(X))
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
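A rough way to localize a failure like this is to compile the same function with progressively more of the pipeline enabled. With backend="eager", only Dynamo's graph capture runs. With "aot_eager", AOTAutograd is added. With "inductor" (the default), code generation is added as well.

The traceback above ends inside Inductor's code cache and process pool. That suggests graph capture itself succeeded and a compile worker died.

The helper below is a hypothetical sketch: find_failing_stage and small_ops are illustrative names, not PyTorch APIs. It relies only on the backend and fullgraph arguments of torch.compile and on torch._dynamo.reset(). The demo call checks only the two backends that need no C/C++ compiler.

```python
import torch
import torch._dynamo


def find_failing_stage(fn, *args, backends=("eager", "aot_eager", "inductor")):
    """Hypothetical helper: compile `fn` with each backend in turn.

    Returns a dict mapping backend name -> None (success) or the raised exception.
    fullgraph=True makes Dynamo raise at the first graph break instead of
    silently splitting the graph, so unsupported code shows up as an error.
    """
    results = {}
    for backend in backends:
        torch._dynamo.reset()  # drop cached graphs so each backend recompiles from scratch
        try:
            compiled = torch.compile(fn, backend=backend, fullgraph=True)
            compiled(*args)
            results[backend] = None
        except Exception as exc:  # record the error and keep checking later stages
            results[backend] = exc
    return results


def small_ops(x):
    # Illustrative stand-in for a model built from many small elementwise ops.
    return torch.sum(torch.exp(-x) * x, dim=1)


# Only the stages that work without a C/C++ toolchain; add "inductor" to test the default backend.
report = find_failing_stage(small_ops, torch.randn(8, 6), backends=("eager", "aot_eager"))
for backend, err in report.items():
    print(backend, "ok" if err is None else f"failed: {type(err).__name__}: {err}")
```

Suppose "eager" already fails. Then the problem is in Dynamo's tracing or guards, and the error message usually names the unsupported construct.

Suppose instead only "inductor" fails with BrokenProcessPool. In that case, setting torch._inductor.config.compile_threads = 1 before compiling should compile kernels in the main process, so the underlying exception is raised directly. This assumes your PyTorch version exposes that option.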

More specifically, I’ve found that when a compiled function calls another function, torch.compile seems to fail, even if that second function can be compiled on its own. Here’s a minimal script that reproduces this. I define a simple probability distribution and a mixture, and show that the distribution’s log_prob method can be compiled but that calling it through the mixture model fails.

import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import time
import torch
torch.set_float32_matmul_precision('high')

import torch._dynamo
torch._dynamo.reset()
torch._dynamo.config.verbose = True

class Expon(torch.nn.Module):
	def __init__(self, scales):
		super().__init__()

		self.scales = scales
		self._log_scales = torch.log(scales)

	def log_prob(self, X):
		return torch.sum(-self._log_scales - (1. / self.scales) * X, dim=1)


class Mixture(torch.nn.Module):
	def __init__(self, dists):
		self.dists = dists
		self.k = len(dists)

	def emissions(self, X):
		e = torch.empty(X.shape[0], self.k, device='cuda:0')
		
		for i, d in enumerate(self.dists):
			e[:, i] = d.log_prob(X)

		return e

n, d = 8, 6
X = torch.randn(n, d).cuda()

mu1 = torch.exp(torch.randn(d)).cuda()
d1 = Expon(mu1)

mu2 = torch.exp(torch.randn(d)).cuda()
d2 = Expon(mu2)

print(d1.log_prob(X))
f = torch.compile(d1.log_prob, mode='reduce-overhead', fullgraph=True)
print(f(X))

print("\n\n")

model = Mixture([d1, d2])
print(model.emissions(X))

f = torch.compile(model.emissions, mode='reduce-overhead', fullgraph=True)
print(f(X))

The output from this is:

tensor([ 0.7098, 11.4090, 16.9829,  2.7388,  2.8006, -8.2336, 21.9166, -2.0553],
       device='cuda:0')
tensor([ 0.7098, 11.4090, 16.9830,  2.7388,  2.8006, -8.2336, 21.9166, -2.0553],
       device='cuda:0')



tensor([[ 0.7098,  3.1660],
        [11.4090,  2.1182],
        [16.9829,  7.2627],
        [ 2.7388,  0.8454],
        [ 2.8006,  3.6874],
        [-8.2336, -2.3724],
        [21.9166, 13.9888],
        [-2.0553,  3.9145]], device='cuda:0')
Traceback (most recent call last):
  File "/users/jmschr/github/torchegranate/test_silent_hmm.py", line 78, in <module>
    print(f(X))
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 364, in _compile
    check_fn = CheckFunctionManager(
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/guards.py", line 547, in __init__
    guard.create(local_builder, global_builder)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_guards.py", line 163, in create
    return self.create_fn(self.source.select(local_builder, global_builder), self)
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/guards.py", line 299, in NN_MODULE
    unimplemented(f"Guard setup for uninitialized class {type(val)}")
  File "/users/jmschr/anaconda3/lib/python3.9/site-packages/torch/_dynamo/exc.py", line 71, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Guard setup for uninitialized class <class '__main__.Mixture'>


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

tl;dr: compiling the log_prob method of the probability distribution succeeds, but compiling the emissions method, which calls log_prob, fails. Any tips on how to get this to work? Or is this a lost cause for now? Thanks!!
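The second error message, "Guard setup for uninitialized class <class '__main__.Mixture'>", points at the Mixture constructor rather than at the nested log_prob call. Mixture.__init__ never calls super().__init__(), so the instance lacks the attributes that nn.Module.__init__ sets up. One of these is training, which Dynamo's module guard checks for.

Below is a minimal sketch, assuming that is the cause. It makes three changes:
- It calls super().__init__().
- It holds the components in an nn.ModuleList.
- It registers the distribution parameters as buffers, so .to()/.cuda() moves them.

It runs on CPU and uses backend="eager", which exercises only Dynamo's graph capture, where this error is raised. To return to the original setup on a GPU, drop the backend argument and pass mode='reduce-overhead'.

```python
import torch


class Expon(torch.nn.Module):
    def __init__(self, scales):
        super().__init__()
        # Buffers follow the module through .to()/.cuda() and appear in state_dict().
        self.register_buffer("scales", scales)
        self.register_buffer("_log_scales", torch.log(scales))

    def log_prob(self, X):
        return torch.sum(-self._log_scales - (1. / self.scales) * X, dim=1)


class Mixture(torch.nn.Module):
    def __init__(self, dists):
        super().__init__()  # sets up `training`, `_modules`, etc.; missing in the original Mixture
        self.dists = torch.nn.ModuleList(dists)  # registers each component as a submodule
        self.k = len(dists)

    def emissions(self, X):
        # Allocate on the input's device/dtype instead of hard-coding 'cuda:0'.
        e = torch.empty(X.shape[0], self.k, device=X.device, dtype=X.dtype)
        for i, dist in enumerate(self.dists):
            e[:, i] = dist.log_prob(X)
        return e


torch.manual_seed(0)
n, d = 8, 6
X = torch.randn(n, d)
model = Mixture([Expon(torch.exp(torch.randn(d))), Expon(torch.exp(torch.randn(d)))])

expected = model.emissions(X)
f = torch.compile(model.emissions, backend="eager", fullgraph=True)
out = f(X)
print(torch.allclose(out, expected))
```

With the constructor fixed, the nested call to each component's log_prob is traced as part of the same graph. Using ModuleList instead of a plain Python list also means model.cuda() moves the components along with the mixture.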

I am facing the same issue with “raise Unsupported(msg)”. Have you figured out how to find the unsupported ops?