Getting Triton to generate all kernels

I’m looking at the debug traces produced by TORCH_COMPILE_DEBUG=1, and I can see that while most of the kernels are generated by Inductor’s Triton/C++ code generator, some kernels (especially mm, addmm, etc.) are offloaded to external libraries/templates. See the example below.

from torch._inductor.select_algorithm import extern_kernels
...
...
extern_kernels.mm(arg0_1, arg1_1, out=buf0)
...
...

I’m wondering whether I can make the Triton/C++ code generator emit code for 100% of my operators, without using extern_kernels.

Thanks!


There’s no global flag for this as far as I know, but poking around here will be your best bet: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py
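
A quick way to see which knobs are available is to introspect that config module from Python (a minimal sketch; the exact set of attributes varies across PyTorch versions):

import torch._inductor.config as inductor_config

# Sketch: print the autotune-related Inductor options and their current values.
for name in dir(inductor_config):
    if "autotune" in name:
        print(name, "=", getattr(inductor_config, name))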


Thanks for pointing me in the right direction, this definitely seems to be the way. From what I’ve seen, gemm kernels are the ones that get predominantly handled by extern_kernels.

If I remove ATEN from config.max_autotune_gemm_backends and enable TORCHINDUCTOR_MAX_AUTOTUNE, then Triton kicks in and generates gemm kernels.
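
For reference, the same settings can also be applied through environment variables, which torch._inductor.config reads at import time. This is a sketch assuming a recent PyTorch; the variable names may differ across versions:

import os

# Sketch: set the Inductor autotune env vars before importing torch,
# since torch._inductor.config reads them when it is first imported.
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE"] = "1"
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS"] = "TRITON"

import torch  # imported after the variables are set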


Hello @trusira @marksaroufim

I followed your lead and made the following changes:

torch._inductor.config.max_autotune_gemm_backends = "TRITON" # removed ATEN
torch._inductor.config.max_autotune = True

But I still get an error as follows:

  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/kernel/mm.py", line 156, in tuned_mm
    return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 991, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amodab01/anaconda3/envs/ml_training/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 723, in __call__
    raise RuntimeError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: RuntimeError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. 
  target: aten.mm.default
  args[0]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[100], stride=[1]))
      ),
      FixedLayout('cuda', torch.float32, size=[1, 100], stride=[100, 1]),
      origins={view}
    )
  )
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[100, 100], stride=[100, 1]))
      ),
      FixedLayout('cuda', torch.float32, size=[100, 100], stride=[1, 100]),
      origins={permute}
    )
  )

Repro:

import torch
import torch._inductor.config
torch._inductor.config.trace.enabled = True
torch._inductor.config.max_autotune_gemm_backends = "TRITON"
torch._inductor.config.max_autotune = True

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(100, 100)
        self.relu = torch.nn.ReLU()
    
    def forward(self, x):
        return self.relu(self.l(x))

m = ToyModel().to(device="cuda:0")

m = torch.compile(m)
input_tensor = torch.randn(100).to(device="cuda:0")
out = m(input_tensor)

Hello,
Is it possible to get some direction on how I can proceed with this? I want to get rid of the extern_kernels calls and just have Inductor generate Triton kernels for everything.

@amodab01 which pytorch version are you running?

Your code runs just fine for me on torch 2.4.0.dev20240506+cu121 if I comment out torch._inductor.config.trace.enabled = True.

With it on, though, it’s an actual bug, and we’re tracking it here: torch._inductor.config.trace.enabled = True crashes · Issue #125642 · pytorch/pytorch · GitHub

I’ve tried it on 2.2.2 and 2.3.0. I’ll try it with torch 2.4.0.dev20240506+cu121, thanks.

Thanks for this, I’m running into the same issue as mentioned in the bug now.
A quick follow-up: is it possible for me to access the Triton kernels that are generated and benchmarked, even with torch._inductor.config.trace.enabled = False?

Yeah, all you need to do is run TORCH_LOGS="output_code" python train.py and you’ll get the kernels printed.
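
If you prefer to turn this on from Python instead of the environment, something along these lines should also work (a sketch; the exact torch._logging API may vary by version):

import torch._logging

# Sketch: enable the same "output_code" artifact logging programmatically.
torch._logging.set_logs(output_code=True)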


Thanks @marksaroufim

Another thing I observed is that when I add a Conv layer to this model, there’s a compile error, irrespective of whether torch._inductor.config.trace.enabled is True or False.

Error:

Traceback (most recent call last):
  File "foo.py", line 28, in <module>
    out = m(input_tensor)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
    return fn(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 979, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 820, in _convert_frame
    result = inner_convert(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
    return _compile(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 701, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 568, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 173, in _fn
    return fn(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 515, in transform
    tracer.run()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 492, in wrapper
    return handle_graph_break(self, inst, speculation.reason)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 561, in handle_graph_break
    self.output.compile_subgraph(self, reason=reason)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1103, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1295, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1386, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 1367, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/__init__.py", line 1745, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 1454, in compile_fx
    return aot_autograd(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 65, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 958, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 685, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 470, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 672, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 447, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 1358, in fw_compiler_base
    return inner_compile(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 483, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 779, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1692, in compile_to_fn
    return self.compile_to_module().call
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1635, in compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1587, in codegen
    self.scheduler = Scheduler(self.buffers)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/scheduler.py", line 1353, in __init__
    self.fuse_nodes()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/scheduler.py", line 1743, in fuse_nodes
    self.fuse_nodes_once()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/scheduler.py", line 1976, in fuse_nodes_once
    if not self.speedup_by_fusion(node1, node2):
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/scheduler.py", line 1881, in speedup_by_fusion
    choice_timings = multi_node.choice_timings
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3677, in choice_timings
    self._choice_timings = self._choice_timings_fn()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1110, in get_timings
    timings = do_autotuning(precompile_fn)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1079, in do_autotuning
    precompile_fn()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1047, in wait_on_futures
    next(iterator)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/concurrent/futures/_base.py", line 621, in result_iterator
    yield fs.pop().result(end_time - time.monotonic())
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1033, in <lambda>
    lambda c: precompile_with_captured_stdout(c),
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 1029, in precompile_with_captured_stdout
    return choice.precompile()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/select_algorithm.py", line 779, in precompile
    self.bmreq.precompile()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/autotune_process.py", line 632, in precompile
    getattr(mod, self.kernel_name).precompile()
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 197, in precompile
    compiled_binary, launcher = self._precompile_config(
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 380, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/compiler/compiler.py", line 273, in compile
    next_module = compile_ir(module, metadata)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 285, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 198, in make_llir
    pm.run(mod)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
IndexError: map::at

Repro:

import torch
import torch._inductor.config
# torch._inductor.config.trace.enabled = True
torch._inductor.config.max_autotune_gemm_backends = "TRITON"
torch._inductor.config.max_autotune = True

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.linear = torch.nn.Linear(262144, 100)
        self.relu = torch.nn.ReLU()
    
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.relu(self.linear(x))

m = ToyModel().to(device="cuda:0")

m = torch.compile(m)
input_tensor = torch.randn(32,3,64,64).to(device="cuda:0")
out = m(input_tensor)

I turned on trace.enabled to generate a debug.log; some of its contents are as follows:

[triton_heuristics.py:382 ERROR] Triton compilation failed: Placeholder.DESCRIPTIVE_NAME
def triton_convolution(arg_X, arg_W, out_ptr0):
    KERNEL_H : tl.constexpr = 3
    KERNEL_W : tl.constexpr = 3
    STRIDE_H : tl.constexpr = 1
    STRIDE_W : tl.constexpr = 1
    PADDING_H : tl.constexpr = 1
    PADDING_W : tl.constexpr = 1
    GROUPS : tl.constexpr = 1
    UNROLL : tl.constexpr = False
    ALLOW_TF32 : tl.constexpr = True
    BLOCK_M : tl.constexpr = 1024
    BLOCK_N : tl.constexpr = 16
    BLOCK_K : tl.constexpr = 16
    X = arg_X
    W = arg_W

    # Tensor dimensions
    BATCH = 32
    IN_C = 3
    IN_H = 64
    IN_W = 64
    OUT_C = 64
    OUT_H = 64
    OUT_W = 64

    # Strides:
    stride_xn = 12288
    stride_xc = 4096
    stride_xh = 64
    stride_xw = 1
    stride_wc_out = 27
    stride_wc_in = 9
    stride_wh = 3
    stride_ww = 1

    nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = nhw % OUT_W
    nh = nhw // OUT_W
    idx_y_h = nh % OUT_H
    idx_n = nh // OUT_H
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)


    group = 0
    GROUP_IN_C = IN_C
    GROUP_OUT_C = OUT_C


    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)


    # Could be simplified, but slightly slower:
    # for i in range(KERNEL_H):
    #     for j in range(KERNEL_W):
    #         for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (ijk % BLOCK_K_COUNT) * BLOCK_K
        ij = ijk // BLOCK_K_COUNT
        i = ij // KERNEL_W
        j = ij % KERNEL_W

        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
        idx_x_c = tl.arange(0, BLOCK_K) + k

        x_ptrs = x_base + (
            (idx_x_h * stride_xh)[:, None]
            + (idx_x_w * stride_xw)[:, None]
            + (idx_x_c * stride_xc)[None, :]
        )
        mask_x = (
            (idx_n < BATCH)[:, None]
            & (idx_x_h >= 0)[:, None]
            & (idx_x_h < IN_H)[:, None]
            & (idx_x_w >= 0)[:, None]
            & (idx_x_w < IN_W)[:, None]
            & (idx_x_c < GROUP_IN_C)[None, :]
        )
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)

        w_ptrs = w_base + (
            (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
        )
        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)



    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    xindex = idx_w + (64*idx_h) + (4096*idx_c) + (262144*idx_n)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)

metadata: {'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32'}, 'device': 0, 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'device_type': 'cuda', 'num_warps': 8, 'num_stages': 1, 'debug': True, 'cc': 75}
Traceback (most recent call last):
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 380, in _precompile_config
    binary = triton.compile(*compile_args, **compile_kwargs)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/compiler/compiler.py", line 273, in compile
    next_module = compile_ir(module, metadata)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 285, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/home/amodab01/anaconda3/envs/test/lib/python3.8/site-packages/triton/backends/nvidia/compiler.py", line 198, in make_llir
    pm.run(mod)
IndexError: map::at

@amodab01 mind opening a new issue at Issues · pytorch/pytorch · GitHub? Most of the torch.compile devs are there.

Sure, I’ve created an issue there: torch._inductor.config.max_autotune_gemm_backends = "TRITON" crashes with Convolution layer · Issue #125728 · pytorch/pytorch · GitHub

@marksaroufim is there any particular reason why PyTorch 2’s Inductor codegen does not generate all kernels as Triton kernels by default, and instead emits a mix of Triton and ATen (native PyTorch) kernels?

Any pointer to documentation explaining this?


This is probably motivated by performance. If highly tuned kernels for common ops like GEMM already exist (e.g. in vendor libraries), it makes sense to call into them rather than regenerate them.