Is there a way to use torch.compile with a model that has a custom, autotuned Triton kernel in the middle? Concretely, I want to use this MoE layer in Languini Kitchen, but torch.compile fails with the gigantic, unparsable error pasted below. Do you have any suggestions on how to even start debugging this?
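For reference, the failing pattern boils down to the following stripped-down sketch (an illustrative kernel, not the actual cvmm code): an @triton.autotune'd @triton.jit kernel launched from inside a torch.autograd.Function, with the whole model wrapped in torch.compile. The crash happens the moment the autotuner starts benchmarking its configs, because Dynamo follows the call into triton/runtime/jit.py.

```python
import torch
import triton
import triton.language as tl

# Illustrative stand-in for the autotuned cvmm kernel (not the real code).
@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=8),
    ],
    key=["n"],
)
@triton.jit
def double_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = torch.empty_like(x)
        n = x.numel()
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
        double_kernel[grid](x, out, n)  # autotuner benchmarks configs here
        return out

    @staticmethod
    def backward(ctx, grad):
        return grad * 2.0

fn = torch.compile(lambda x: Double.apply(x).sum())
fn(torch.randn(4096, device="cuda"))  # dies with the error below
```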
[2023-11-30 14:40:53,370] [7/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-11-30 14:40:53,370] [7/0] torch._dynamo.variables.higher_order_ops: [ERROR] call_function partial in skip_files /usr/lib/python3.10/functools.py
Traceback (most recent call last):
File "/home/robert/languini-kitchen/languini/common_lib/throughput.py", line 132, in <module>
main()
File "/home/robert/languini-kitchen/languini/common_lib/throughput.py", line 82, in main
result = throughput_utils.throughput_test(config=config, model=model)
File "/home/robert/languini-kitchen/languini/common_lib/throughput_utils.py", line 79, in throughput_test
step(inputs, targets, state)
File "/home/robert/languini-kitchen/languini/common_lib/throughput_utils.py", line 65, in step
logits, _ = model(inputs, state)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/robert/languini-kitchen/languini/projects/moe/model.py", line 91, in forward
x = layer(x, log=log)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/robert/languini-kitchen/languini/projects/moe/lib.py", line 190, in forward
mlp_x = self.mlp(self.ln2(x, log=log), log=log)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/robert/languini-kitchen/languini/projects/moe/lib.py", line 159, in forward
x, reg_loss = self.moe(x)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/moe_layer_simple.py", line 166, in forward
scores = self.compute_scores(input, sel_indices)
File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/moe_layer_simple.py", line 97, in compute_scores
scores = cvmm(input, index, self.keys)
File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/cvmm.py", line 473, in cvmm
return CVMM.apply(x, sel.sel_index, sel.sel, keys, sel.out_index, sel.reduction_weight)
File "/home/robert/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/cvmm.py", line 409, in forward
res = cvmm_triton(x, sel_index, sel, keys, out_type, out_index)
File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/cvmm.py", line 351, in cvmm_triton
cvmm_kernel[grid](
File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/cvmm.py", line 351, in <resume in cvmm_triton>
cvmm_kernel[grid](
File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 99, in run
bench_start = time.time()
File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <resume in run>
timings = {config: self._bench(*args, config=config, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 74, in _bench
current = dict(meta, **config.kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 83, in <resume in _bench>
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/home/robert/.local/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench
fn()
File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 81, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "<string>", line 4, in cvmm_kernel
File "<string>", line 4, in <resume in cvmm_kernel>
File "<string>", line 4, in <resume in cvmm_kernel>
File "<string>", line 4, in <resume in cvmm_kernel>
File "<string>", line 4, in <resume in cvmm_kernel>
File "<string>", line 4, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 6, in <resume in cvmm_kernel>
File "<string>", line 20, in <resume in cvmm_kernel>
File "<string>", line 20, in <resume in cvmm_kernel>
File "<string>", line 20, in <resume in cvmm_kernel>
File "<string>", line 20, in <resume in cvmm_kernel>
File "<string>", line 20, in <resume in cvmm_kernel>
File "<string>", line 20, in <resume in cvmm_kernel>
File "<string>", line 21, in <resume in cvmm_kernel>
File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 173, in _pinned_memory_of
if hasattr(arg, "is_pinned"):
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 586, in _compile
raise InternalTorchDynamoError(str(e)).with_traceback(
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
tracer.run()
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
super().run()
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 618, in call_function
result = handler(tx, *args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 953, in call_isinstance
arg_type = arg.python_type()
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 221, in python_type
raise NotImplementedError(f"{self} has no type")
torch._dynamo.exc.InternalTorchDynamoError: GetAttrVariable(TensorVariable(), is_pinned) has no type
from user code:
File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 174, in <resume in _pinned_memory_of>
if isinstance(arg.is_pinned, Callable):
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
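Setting torch._dynamo.config.suppress_errors = True as the message suggests would presumably just fall back to eager for the whole frame and lose the benefit of compiling, so ideally I'd like Dynamo to skip only the kernel launch. Would something like the following be a reasonable starting point: pull the launch into a helper and mark it opaque to Dynamo, so torch.compile graph-breaks there instead of tracing into triton/runtime/jit.py? This is an untested sketch; launch_cvmm is a name I made up, and cvmm_triton is the real launcher from cvmm.py in the trace above.

```python
import torch._dynamo
from languini.projects.moe.moe_layer.cvmm import cvmm_triton  # path as in the trace

# Sketch (untested): make the Triton launch opaque to Dynamo so that
# torch.compile graph-breaks here rather than stepping into Triton's
# autotuner and JIT launcher.
@torch._dynamo.disable
def launch_cvmm(x, sel_index, sel, keys, out_type, out_index):
    return cvmm_triton(x, sel_index, sel, keys, out_type, out_index)
```

CVMM.forward would then call launch_cvmm(...) in place of the direct cvmm_triton(...) call at cvmm.py:409. Is that the intended way to mix torch.compile with custom autotuned kernels, or is there a better-supported mechanism?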