Thanks @marksaroufim
Another thing I observed: when I add a Conv layer to this model, there is a compile error regardless of whether `torch._inductor.config.trace.enabled` is True or False.
Error:
Traceback (most recent call last):
File "foo.py", line 28, in <module>
out = m(input_tensor)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
return fn(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 979, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 820, in _convert_frame
result = inner_convert(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
return _compile(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_utils_internal.py", line 70, in wrapper_function
return function(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 701, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 568, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
transformations(instructions, code_options)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 173, in _fn
return fn(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 515, in transform
tracer.run()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2234, in run
super().run()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 884, in run
while self.step():
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 799, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 492, in wrapper
return handle_graph_break(self, inst, speculation.reason)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 561, in handle_graph_break
self.output.compile_subgraph(self, reason=reason)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1103, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1295, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1386, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1367, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/__init__.py", line 1745, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 1454, in compile_fx
return aot_autograd(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 65, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 958, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 685, in create_aot_dispatcher_function
compiled_fn = compiler_fn(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 470, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 672, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 447, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 1358, in fw_compiler_base
return inner_compile(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/debug.py", line 304, in inner
return fn(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 483, in compile_fx_inner
compiled_graph = fx_codegen_and_compile(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 779, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1692, in compile_to_fn
return self.compile_to_module().call
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1635, in compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1587, in codegen
self.scheduler = Scheduler(self.buffers)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/scheduler.py", line 1353, in __init__
self.fuse_nodes()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/scheduler.py", line 1743, in fuse_nodes
self.fuse_nodes_once()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/scheduler.py", line 1976, in fuse_nodes_once
if not self.speedup_by_fusion(node1, node2):
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/scheduler.py", line 1881, in speedup_by_fusion
choice_timings = multi_node.choice_timings
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3677, in choice_timings
self._choice_timings = self._choice_timings_fn()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1110, in get_timings
timings = do_autotuning(precompile_fn)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1079, in do_autotuning
precompile_fn()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1047, in wait_on_futures
next(iterator)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/concurrent/futures/_base.py", line 621, in result_iterator
yield fs.pop().result(end_time - time.monotonic())
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/concurrent/futures/thread.py", line 57, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1033, in <lambda>
lambda c: precompile_with_captured_stdout(c),
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1029, in precompile_with_captured_stdout
return choice.precompile()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 779, in precompile
self.bmreq.precompile()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/autotune_process.py", line 632, in precompile
getattr(mod, self.kernel_name).precompile()
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 197, in precompile
compiled_binary, launcher = self._precompile_config(
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 380, in _precompile_config
binary = triton.compile(*compile_args, **compile_kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/compiler/compiler.py", line 273, in compile
next_module = compile_ir(module, metadata)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 285, in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 198, in make_llir
pm.run(mod)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
IndexError: map::at
Repro:
# Repro script: compile a small Conv2d -> Linear -> ReLU model with
# TorchInductor max-autotune restricted to the TRITON backend. This
# triggers the "IndexError: map::at" Triton compilation failure
# reported above. NOTE: requires a CUDA device.
import torch
import torch._inductor.config

# torch._inductor.config.trace.enabled = True  # error occurs whether this is on or off
torch._inductor.config.max_autotune_gemm_backends = "TRITON"
torch._inductor.config.max_autotune = True


class ToyModel(torch.nn.Module):
    """Minimal Conv -> flatten -> Linear -> ReLU model."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        # Conv output for a (N, 3, 64, 64) input is (N, 64, 64, 64):
        # 64 * 64 * 64 = 262144 features per sample after flattening.
        self.linear = torch.nn.Linear(262144, 100)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten all but the batch dimension
        return self.relu(self.linear(x))


m = ToyModel().to(device="cuda:0")
m = torch.compile(m)
input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
out = m(input_tensor)  # raises BackendCompilerFailed (IndexError: map::at)
I turned on `trace.enabled` to generate a `debug.log`; some of its contents are as follows:
[triton_heuristics.py:382 ERROR] Triton compilation failed: Placeholder.DESCRIPTIVE_NAME
def triton_convolution(arg_X, arg_W, out_ptr0):
KERNEL_H : tl.constexpr = 3
KERNEL_W : tl.constexpr = 3
STRIDE_H : tl.constexpr = 1
STRIDE_W : tl.constexpr = 1
PADDING_H : tl.constexpr = 1
PADDING_W : tl.constexpr = 1
GROUPS : tl.constexpr = 1
UNROLL : tl.constexpr = False
ALLOW_TF32 : tl.constexpr = True
BLOCK_M : tl.constexpr = 1024
BLOCK_N : tl.constexpr = 16
BLOCK_K : tl.constexpr = 16
X = arg_X
W = arg_W
# Tensor dimensions
BATCH = 32
IN_C = 3
IN_H = 64
IN_W = 64
OUT_C = 64
OUT_H = 64
OUT_W = 64
# Strides:
stride_xn = 12288
stride_xc = 4096
stride_xh = 64
stride_xw = 1
stride_wc_out = 27
stride_wc_in = 9
stride_wh = 3
stride_ww = 1
nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
idx_y_w = nhw % OUT_W
nh = nhw // OUT_W
idx_y_h = nh % OUT_H
idx_n = nh // OUT_H
idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
group = 0
GROUP_IN_C = IN_C
GROUP_OUT_C = OUT_C
x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
w_base = (
W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Could be simplified, but slightly slower:
# for i in range(KERNEL_H):
# for j in range(KERNEL_W):
# for k in range(0, GROUP_IN_C, BLOCK_K):
BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
k = (ijk % BLOCK_K_COUNT) * BLOCK_K
ij = ijk // BLOCK_K_COUNT
i = ij // KERNEL_W
j = ij % KERNEL_W
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
idx_x_c = tl.arange(0, BLOCK_K) + k
x_ptrs = x_base + (
(idx_x_h * stride_xh)[:, None]
+ (idx_x_w * stride_xw)[:, None]
+ (idx_x_c * stride_xc)[None, :]
)
mask_x = (
(idx_n < BATCH)[:, None]
& (idx_x_h >= 0)[:, None]
& (idx_x_h < IN_H)[:, None]
& (idx_x_w >= 0)[:, None]
& (idx_x_w < IN_W)[:, None]
& (idx_x_c < GROUP_IN_C)[None, :]
)
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
w_ptrs = w_base + (
(idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
)
mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
mask = (
(idx_n < BATCH)[:, None]
& (idx_y_h < OUT_H)[:, None]
& (idx_y_w < OUT_W)[:, None]
& (idx_y_c < GROUP_OUT_C)[None, :]
)
idx_n = idx_n[:, None]
idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
idx_h = idx_y_h[:, None]
idx_w = idx_y_w[:, None]
# inductor generates a suffix
xindex = idx_w + (64*idx_h) + (4096*idx_c) + (262144*idx_n)
tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)
metadata: {'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32'}, 'device': 0, 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'device_type': 'cuda', 'num_warps': 8, 'num_stages': 1, 'debug': True, 'cc': 75}
Traceback (most recent call last):
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 380, in _precompile_config
binary = triton.compile(*compile_args, **compile_kwargs)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/compiler/compiler.py", line 273, in compile
next_module = compile_ir(module, metadata)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 285, in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 198, in make_llir
pm.run(mod)
IndexError: map::at