Torch.compile with custom Triton kernel

Is there a way to use torch.compile with a model that uses a custom, autotuned Triton kernel in the middle? Concretely, I want to use this MoE layer in Languini Kitchen, but I get the gigantic, unparsable error message below. Do you have any suggestions on how to even start debugging this?

[2023-11-30 14:40:53,370] [7/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-11-30 14:40:53,370] [7/0] torch._dynamo.variables.higher_order_ops: [ERROR] call_function partial in skip_files /usr/lib/python3.10/functools.py
Traceback (most recent call last):
  File "/home/robert/languini-kitchen/languini/common_lib/throughput.py", line 132, in <module>
    main()
  File "/home/robert/languini-kitchen/languini/common_lib/throughput.py", line 82, in main
    result = throughput_utils.throughput_test(config=config, model=model)
  File "/home/robert/languini-kitchen/languini/common_lib/throughput_utils.py", line 79, in throughput_test
    step(inputs, targets, state)
  File "/home/robert/languini-kitchen/languini/common_lib/throughput_utils.py", line 65, in step
    logits, _ = model(inputs, state)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/languini-kitchen/languini/projects/moe/model.py", line 91, in forward
    x = layer(x, log=log)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/languini-kitchen/languini/projects/moe/lib.py", line 190, in forward
    mlp_x = self.mlp(self.ln2(x, log=log), log=log)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/languini-kitchen/languini/projects/moe/lib.py", line 159, in forward
    x, reg_loss = self.moe(x)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/moe_layer_simple.py", line 166, in forward
    scores = self.compute_scores(input, sel_indices)
  File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/moe_layer_simple.py", line 97, in compute_scores
    scores = cvmm(input, index, self.keys)
  File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/cvmm.py", line 473, in cvmm
    return CVMM.apply(x, sel.sel_index, sel.sel, keys, sel.out_index, sel.reduction_weight)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/cvmm.py", line 409, in forward
    res = cvmm_triton(x, sel_index, sel, keys, out_type, out_index)
  File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/cvmm.py", line 351, in cvmm_triton
    cvmm_kernel[grid](
  File "/home/robert/languini-kitchen/languini/projects/moe/moe_layer/cvmm.py", line 351, in <resume in cvmm_triton>
    cvmm_kernel[grid](
  File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 99, in run
    bench_start = time.time()
  File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <resume in run>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 74, in _bench
    current = dict(meta, **config.kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 83, in <resume in _bench>
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/home/robert/.local/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 4, in cvmm_kernel
  File "<string>", line 4, in <resume in cvmm_kernel>
  File "<string>", line 4, in <resume in cvmm_kernel>
  File "<string>", line 4, in <resume in cvmm_kernel>
  File "<string>", line 4, in <resume in cvmm_kernel>
  File "<string>", line 4, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 6, in <resume in cvmm_kernel>
  File "<string>", line 20, in <resume in cvmm_kernel>
  File "<string>", line 20, in <resume in cvmm_kernel>
  File "<string>", line 20, in <resume in cvmm_kernel>
  File "<string>", line 20, in <resume in cvmm_kernel>
  File "<string>", line 20, in <resume in cvmm_kernel>
  File "<string>", line 20, in <resume in cvmm_kernel>
  File "<string>", line 21, in <resume in cvmm_kernel>
  File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 173, in _pinned_memory_of
    if hasattr(arg, "is_pinned"):
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 586, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 953, in call_isinstance
    arg_type = arg.python_type()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 221, in python_type
    raise NotImplementedError(f"{self} has no type")
torch._dynamo.exc.InternalTorchDynamoError: GetAttrVariable(TensorVariable(), is_pinned) has no type

from user code:
   File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 174, in <resume in _pinned_memory_of>
    if isinstance(arg.is_pinned, Callable):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I'm not sure if there's an elegant way of doing this, but I believe you need to somehow register the custom kernel with the codecache. When Triton kernels are generated automatically, this happens internally so that the correct kernels are mapped to their corresponding operators.
I'm also interested to hear what the community has to say.

To me, the problem seems to be the autotuning and not the kernel itself. I also don't see how the autotuning could work in the middle of a compiled model, unless Triton can communicate with PyTorch and recompile the whole model for every configuration. My colleague also tried a different setup with a Triton kernel without tuning, and he said it worked. I have not had time to test it yet, but I will try a minimal example.

Triton can and does communicate with PyTorch for PTX/cubin codegen. Furthermore, I see PyTorch implements its own lightweight autotuner (the CachingAutotuner class), though I'm a little confused as to who (between Triton and PyTorch) actually handles kernel launching at runtime. I asked this in a different post here.

AFAIK, the autotuning apparatus is used irrespective of whether you're autotuning over multiple configs or not. In the single-config (i.e., no real autotuning) case, it will just generate a single kernel and launch it; see def run here. The multiple-config case is handled separately:

if len(self.launchers) != 1:
    if len(self.launchers) == 0:
        self.precompile()
    if len(self.launchers) > 1:
        self.autotune_to_one_config(*args, grid=grid, **kwargs)

However, I'm not certain at this point whether TorchInductor actually reuses the Triton JIT runtime or has its own mechanism for launching kernels.
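
One way to check would be to dump the code Inductor generates for a compiled function; the generated Python wrapper shows how its Triton kernels are defined and launched. A quick sketch (the script name is arbitrary and the exact output layout varies by PyTorch version):

import torch

# Run with: TORCH_LOGS="output_code" python inspect_inductor.py
# (TORCH_COMPILE_DEBUG=1 also works and writes the generated files to disk.)

@torch.compile
def f(x):
    return torch.relu(x) + 1.0

# The logged "output code" contains the Inductor-generated Triton kernels plus
# the Python wrapper that compiles and launches them.
f(torch.randn(1024, device="cuda"))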

It is actually not the autotuning. Here's a minimal example based on the Triton vector-addition tutorial. Tested with torch==2.1.0; it fails with the same error message as before.

import torch

import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


def add_fwd(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output


class TritonAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        return add_fwd(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, grad_output

triton_add = TritonAdd.apply

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return triton_add(x, y)

model = torch.compile(Model().cuda())

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = model(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

Output:

Traceback (most recent call last):
  File "/home/robert/rnn_generalization_test/compile_test.py", line 53, in <module>
    output_triton = model(x, y)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/rnn_generalization_test/compile_test.py", line 44, in forward
    return triton_add(x, y)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/robert/rnn_generalization_test/compile_test.py", line 31, in forward
    return add_fwd(x, y)
  File "/home/robert/rnn_generalization_test/compile_test.py", line 24, in add_fwd
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
  File "/home/robert/rnn_generalization_test/compile_test.py", line 24, in <resume in add_fwd>
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
  File "<string>", line 4, in add_kernel
  File "<string>", line 4, in <resume in add_kernel>
  File "<string>", line 4, in <resume in add_kernel>
  File "<string>", line 6, in <resume in add_kernel>
  File "<string>", line 6, in <resume in add_kernel>
  File "<string>", line 6, in <resume in add_kernel>
  File "<string>", line 6, in <resume in add_kernel>
  File "<string>", line 6, in <resume in add_kernel>
  File "<string>", line 6, in <resume in add_kernel>
  File "<string>", line 20, in <resume in add_kernel>
  File "<string>", line 20, in <resume in add_kernel>
  File "<string>", line 20, in <resume in add_kernel>
  File "<string>", line 21, in <resume in add_kernel>
  File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 173, in _pinned_memory_of
    if hasattr(arg, "is_pinned"):
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 586, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 618, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 953, in call_isinstance
    arg_type = arg.python_type()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 221, in python_type
    raise NotImplementedError(f"{self} has no type")
torch._dynamo.exc.InternalTorchDynamoError: GetAttrVariable(TensorVariable(), is_pinned) has no type

from user code:
   File "/home/robert/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 174, in <resume in _pinned_memory_of>
    if isinstance(arg.is_pinned, Callable):

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

There's some ongoing work to make torch.compile understand Triton kernels (and inline a user's Triton kernel into the Inductor-generated output). You can probably try it out with a nightly.
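
In the meantime, a stop-gap that usually works on stable releases is to keep Dynamo from tracing into the kernel launch at all, at the cost of a graph break around the op. A sketch against the minimal add example above (add_fwd and the surrounding names come from that post; untested beyond it):

import torch
import torch._dynamo

# torch._dynamo.disable marks a function so Dynamo skips it and runs it eagerly,
# which means the Triton launch inside is never traced. The compiled model gets
# a graph break at this call instead of the InternalTorchDynamoError.
# (torch.compiler.disable is the equivalent public API in torch >= 2.1.)

@torch._dynamo.disable
def add_fwd_eager(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return add_fwd(x, y)  # add_fwd as defined in the minimal example above


class TritonAddEager(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        return add_fwd_eager(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, grad_output


triton_add = TritonAddEager.apply  # drop-in replacement for the original triton_add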

Thanks for the tip! I tried it out, and the previous simple example now compiles. However, the full version still doesn't work. At first the kernel didn't compile at all anymore, failing with an MLIR error; after reinstalling Triton the MLIR error is gone, but the rest of the errors remain. Here are some traces (for the current, non-minimal code):

After reinstalling Triton:

Is there a particular GitHub issue/PR I could follow for that work?

There isn't currently a GitHub tracker, but here are some of the recent PRs:

I checked with Oguz (the PR author); he said feel free to file GitHub issues (preferably with minimal reproductions) if you run into any problems.

Hi!

First, thanks for the info.

I came up with a more minimal failing example, but it is still not super short. I removed most of the logic from the CVMM kernel, and this is what I am left with:

import torch

import triton
import triton.language as tl

# CVMM from: https://github.com/RobertCsordas/moe_layer/blob/master/triton_src/moe_layer/cvmm.py, simplified

from typing import Union, Optional
from dataclasses import dataclass


@dataclass
class CVMMSel:
    raw_sel: torch.Tensor
    sel: torch.Tensor
    sel_index: torch.Tensor
    out_index: Optional[torch.Tensor] = None


def cvmm_prepare_sel(sel: torch.Tensor, n_experts: int) -> CVMMSel:
    fsel = sel.flatten()
    ssel, sel_index = fsel.sort()
    return CVMMSel(sel, ssel.view_as(sel), sel_index, None)



@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K', 'float32', 'allow_tf32']
)
@triton.jit
def cvmm_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bo, stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_index, stride_sel, stride_out_index,
    float32: tl.constexpr, allow_tf32: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_n = (pid % num_pid_in_group) // group_size_m

    pid_m = first_pid_m + (pid % group_size_m)

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M

    remap_offs_am = tl.load(index_ptr + stride_index * offs_am)

    # Create offset pointers
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    if not float32:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)


    # !!!!!! Removing this IF solves the "RuntimeError: CUDA error: an illegal memory access was encountered" problem,
    # even though it is always False in this example !!!!!!
    # To test it, keep the else branch.
    if out_index_ptr is not None:
        remap_offs_cm = tl.load(out_index_ptr + stride_out_index * offs_am)
    else:
        remap_offs_cm = remap_offs_am

    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * remap_offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)



def cvmm_triton(x: torch.Tensor, sel_index: torch.Tensor, sel: torch.Tensor, keys: torch.Tensor, out_dtype: torch.dtype, out_index: Optional[torch.Tensor] = None):
    x = x.flatten(end_dim=-2)
    assert x.shape[-1] == keys.shape[1]

    sel_shape = sel.shape
    sel = sel.flatten()

    M = sel.shape[0]
    O, K, N = keys.shape
    # Allocates output.
    out = torch.empty((M, N), device=x.device, dtype=out_dtype)

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )

    cvmm_kernel[grid](
        x, keys, out, sel_index, sel, out_index,
        M, N, K,
        x.stride(0), x.stride(1),
        keys.stride(0), keys.stride(1), keys.stride(2),
        out.stride(0), out.stride(1),
        sel_index.stride(0), sel.stride(0), out_index.stride(0) if out_index is not None else 0,
        float32=out.dtype==torch.float32, allow_tf32=False, #torch.backends.cuda.matmul.allow_tf32
    )

    return out.view(*sel_shape, N)


def cvmm_triton_backward(x: torch.Tensor, sel_index: torch.Tensor, sel: torch.Tensor, grads: torch.Tensor, n_experts: int, key_dtype: torch.dtype,
                         op_float16: bool, out_index: Optional[torch.Tensor] = None):
    x = x.flatten(end_dim=-2)
    x = x.transpose(0, 1)

    grads = grads.flatten(end_dim=-2)
    sel = sel.flatten()

    M, _ = x.shape
    K, N = grads.shape

    out = torch.zeros((n_experts, M, N), device=x.device, dtype=key_dtype)
    return out



class CVMM(torch.autograd.Function):
    warned = False

    @staticmethod
    def forward(ctx, x: torch.Tensor, sel_index: torch.Tensor, sel: torch.Tensor, keys: torch.Tensor, out_index: Optional[torch.Tensor] = None):
        ctx.save_for_backward(x, keys, sel, sel_index, out_index)

        out_type = torch.float16 if torch.is_autocast_enabled() else x.dtype
        res = cvmm_triton(x, sel_index, sel, keys, out_type, out_index)
        ctx.op_type = out_type
        ctx.keys_type = keys.dtype
        ctx.is_autocast = torch.is_autocast_enabled()
        return res

    @staticmethod
    def backward(ctx, grad_output):
        x, keys, sel, sel_index, out_index = ctx.saved_tensors

        keys_dt = keys
        grad_w = cvmm_triton_backward(x, sel_index, sel, grad_output, keys_dt.shape[0], ctx.keys_type, ctx.is_autocast, out_index=out_index)

        bw_index = sel_index if out_index is None else out_index
        bw_index_out = None
        grad_x_full = cvmm_triton(grad_output, bw_index, sel, keys_dt.transpose(1,2), ctx.op_type, bw_index_out)

        grad_x = grad_x_full.view(*x.shape[:-1], -1, x.shape[-1])
        grad_x = grad_x.view_as(x)

        return grad_x, None, None, grad_w, None


def cvmm(x: torch.Tensor, sel: Union[torch.Tensor, CVMMSel], keys: torch.Tensor):
    if not isinstance(sel, CVMMSel):
        sel = cvmm_prepare_sel(sel, keys.shape[0])

    return CVMM.apply(x, sel.sel_index, sel.sel, keys, sel.out_index)

# Compile test


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, sel, w):
        return cvmm(x, sel, w)

model = torch.compile(Model().cuda())
# model = Model().cuda()


torch.manual_seed(0)
n_experts = 8
n_channels = 64
expert_size = 64
bs = 64

device = torch.device("cuda")
dtype = torch.float16
atol_tresh = 1e-2

keys = torch.nn.Parameter(torch.randn(n_experts, n_channels, expert_size, dtype=dtype, device=device))
keys = keys.transpose(1,2).contiguous().transpose(1,2)
testvec = torch.randn(bs, n_channels, dtype=dtype, device=device)
sel = torch.randint(0, n_experts, (bs,), dtype=torch.int32, device=device)

print(model(testvec, sel, keys).shape)

There are two independent issues: the "torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode" errors, and the "RuntimeError: CUDA error: an illegal memory access was encountered" error. Neither of them happens without compilation.

Interestingly, the illegal memory access error is caused by an if statement whose condition is never true. I marked it with:

# !!!!!! Removing this IF solves the "RuntimeError: CUDA error: an illegal memory access was encountered" problem,
# even though it is always False in this example !!!!!!

If only the else branch is kept, this minimal example runs, although it still prints the "torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode" errors.

Please find the output of the script here.

I created a new issue for this: Custom Triton kernel "CUDA error: an illegal memory access was encountered" with Torch 2.2.0 nightly · Issue #115344 · pytorch/pytorch · GitHub
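
For completeness, a Triton pattern that sidesteps branching on a possibly-None pointer argument (in case that comparison turns out to be the culprit) is to pass an explicit tl.constexpr flag, so the branch is resolved when the kernel is compiled and the unused pointer is never touched. A minimal, self-contained sketch of the pattern, using a hypothetical copy_kernel rather than the cvmm code above, and untested against this repro:

from typing import Optional

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, dst2_ptr, n_elements,
                HAS_SECOND_OUTPUT: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, x, mask=mask)
    # Resolved at Triton compile time: when HAS_SECOND_OUTPUT is False, this
    # store is eliminated and dst2_ptr is never dereferenced.
    if HAS_SECOND_OUTPUT:
        tl.store(dst2_ptr + offsets, x, mask=mask)


def copy(src: torch.Tensor, dst2: Optional[torch.Tensor] = None) -> torch.Tensor:
    dst = torch.empty_like(src)
    n = src.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    copy_kernel[grid](
        src, dst,
        dst2 if dst2 is not None else dst,  # dummy pointer, unused when the flag is False
        n,
        HAS_SECOND_OUTPUT=dst2 is not None,
        BLOCK_SIZE=1024,
    )
    return dst


x = torch.randn(4096, device="cuda")
y = copy(x)            # second output disabled
extra = torch.empty_like(x)
y2 = copy(x, extra)    # second output enabled; extra now holds a copy of x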