Custom Triton kernel with torch compile returns incorrect outputs

I have a custom Triton kernel that is wrapped in an autograd Function. When the function is compiled with torch.compile, the outputs are incorrect and do not match those of the non-compiled version; when the function is run without compiling, the outputs are correct.

I have localized the issue. Here is the Autograd function definition:

class FusedChunkLinearAttentionFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, scale, initial_state, output_final_state):
        batch_size, n_heads, seq_len, d_head_qk = q.shape
        d_head_v = v.shape[-1]
        ctx.scale = scale

        # Convert PyTorch dtype to Triton dtype
        if q.dtype == torch.float16:
            triton_dtype = tl.float16
        elif q.dtype == torch.bfloat16:
            triton_dtype = tl.bfloat16
        else:
            triton_dtype = tl.float32

        # Initialize the output tensors
        BK = min(triton.next_power_of_2(d_head_qk), 64)
        NK = triton.cdiv(d_head_qk, BK)
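        # The QK head dim is split into NK blocks; each block writes its partial output into o[nk], and the partials are summed over dim 0 afterwards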
        o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
        if output_final_state:
            final_state = q.new_empty(
                batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False
            )
        else:
            final_state = None

        grid = lambda args: (  # noqa: E731
            triton.cdiv(v.shape[-1], args["BV"]),
            NK,
            batch_size * n_heads,
        )
        fused_chunk_linear_attn_fwd_kernel[grid](
            q,
            k,
            v,
            o,
            initial_state,
            final_state,
            scale,
            q.stride(1),
            q.stride(2),
            q.stride(3),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            batch_size,
            n_heads,
            seq_len,
            d_head_qk,
            d_head_v,
            BK=BK,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            DTYPE=triton_dtype,
        )
        ctx.save_for_backward(q, k, v, initial_state)
        ctx.triton_dtype = triton_dtype

        # This code returns incorrect outputs when running with torch compile
        # o = o.sum(0)
        # return o.to(q.dtype), final_state

        # If I break the compilation graph with print(), then the code returns correct outputs
        if NK > 1:
            print("")  # this is needed to break the torch compile graph
            o = o.sum(0)
            return o.to(q.dtype), final_state
        else:
            # in our setting we can avoid this since NK will be one
            return o[0].to(q.dtype), final_state

The issue stems from the final

o.sum(0)

since it sums up partial results that are computed by independent thread blocks on the GPU. torch.compile tries to compile the full function into a single graph, which yields incorrect values; I speculate that the sum is executed before all the partial results are actually available.

If I instead force a graph break before the sum operation, the outputs are correct and match the eager-mode values.
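
To make the structure concrete, this is the pattern in isolation (a toy copy kernel just to illustrate the partial-buffer-then-sum shape; I am not claiming this minimal version reproduces the bug by itself):

import torch
import triton
import triton.language as tl


@triton.jit
def toy_partial_kernel(x_ptr, o_ptr, N, BLOCK: tl.constexpr):
    # Each program instance writes one slice of the (NK, N) partial buffer
    pid_k = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(o_ptr + pid_k * N + offs, x, mask=mask)


def toy_forward(x, NK=2):
    # x is assumed to be a contiguous CUDA tensor
    N = x.numel()
    o = x.new_empty(NK, N)
    toy_partial_kernel[(NK,)](x, o, N, BLOCK=triton.next_power_of_2(N))
    # Same host-side reduction over the partial dimension as o.sum(0) above
    return o.sum(0)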

Two questions:

  • Is it expected behaviour that the outputs are incorrect in this case? I would expect torch.compile to wait until the Triton kernel has fully finished and then compute the sum correctly.
  • If yes, is there a more principled way to break the graph than calling print()?

Based on these docs you might want to use torch.library.register_autograd instead of autograd.Function for your custom operator:

Use torch.library.register_autograd to add training support for an operator. Prefer this over directly using torch.autograd.Function; some compositions of autograd.Function with PyTorch operator registration APIs can lead to (and has led to) silent incorrectness when composed with torch.compile.
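
A minimal skeleton of that pattern (with a placeholder op just to show the wiring, not your kernel) would look roughly like this:

import torch


@torch.library.custom_op("mylib::toy_double", mutates_args=())
def toy_double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@toy_double.register_fake
def _(x):
    # Only output metadata is produced here; no real computation
    return torch.empty_like(x)


def _setup_context(ctx, inputs, output):
    pass  # nothing needs to be saved for this toy op


def _backward(ctx, grad):
    return grad * 2


toy_double.register_autograd(_backward, setup_context=_setup_context)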

Hi @ptrblck

Thanks for the info. I tried to rewrite the function using the torch custom ops library; however, I run into the following error when trying to compile the model:

  fused_chunk_linear_attn_bwd_kernel[grid](
  File "/home/jwx1126697/anaconda3/envs/nightly/lib/python3.11/site-packages/triton/runtime/jit.py", line 207, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jwx1126697/anaconda3/envs/nightly/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 123, in run
    if key not in self.cache:
       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jwx1126697/anaconda3/envs/nightly/lib/python3.11/site-packages/torch/__init__.py", line 311, in __hash__
    raise TypeError("unhashable type: non-nested SymInt")
TypeError: unhashable type: non-nested SymInt

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jwx1126697/workspace/src/finetuning_llm/ops/triton_dijiang_kernel.py", line 748, in <module>
    torch.library.opcheck(fused_chunk_linear_attn_kernel, example)
  File "/home/jwx1126697/anaconda3/envs/nightly/lib/python3.11/site-packages/torch/library.py", line 931, in opcheck
    return optests.opcheck(op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jwx1126697/anaconda3/envs/nightly/lib/python3.11/site-packages/torch/testing/_internal/optests/generate_tests.py", line 664, in opcheck
    raise OpCheckError(
torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): test_aot_dispatch_dynamic failed with unhashable type: non-nested SymInt (scroll up for stack trace)

As a workaround for now, I simplified the parallelization of the Triton kernel so that I could drop the

o = o.sum(0)

part of the code, which was the part causing issues with torch.compile. This is not ideal, since it reduces the performance of the kernel, so I would still be interested in fixing the issue above. Any help is appreciated!
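
Concretely, the simplification forces a single QK block so that there is nothing to reduce afterwards (a rough sketch of the change, assuming the whole QK head dim fits into one block):

# Force NK == 1 so the partial-output dimension and the o.sum(0) reduction
# disappear, at the cost of a single, larger BK block
BK = triton.next_power_of_2(d_head_qk)
o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)
grid = lambda args: (triton.cdiv(d_head_v, args["BV"]), 1, batch_size * n_heads)
# ... same fused_chunk_linear_attn_fwd_kernel launch as before, just without the NK split ...
return o.to(q.dtype), final_state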

Here is the code defining the torch custom op for the triton kernel:


@torch.library.custom_op("mylib::fused_chunk_linear_attn_kernel", mutates_args=())
def fused_chunk_linear_attn_kernel(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor | None,
    output_final_state: bool,
) -> List[torch.Tensor]:
    batch_size, n_heads, seq_len, d_head_qk = q.shape
    d_head_v = v.shape[-1]

    # Convert PyTorch dtype to Triton dtype
    if q.dtype == torch.float16:
        triton_dtype = tl.float16
    elif q.dtype == torch.bfloat16:
        triton_dtype = tl.bfloat16
    else:
        triton_dtype = tl.float32

    # Initialize the output tensors
    BK = min(triton.next_power_of_2(d_head_qk), 64)
    NK = triton.cdiv(d_head_qk, BK)
    o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
    final_state = q.new_empty(
        batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False
    )

    grid = lambda args: (  # noqa: E731
        triton.cdiv(v.shape[-1], args["BV"]),
        NK,
        batch_size * n_heads,
    )
    fused_chunk_linear_attn_fwd_kernel[grid](
        q,
        k,
        v,
        o,
        initial_state,
        final_state,
        scale,
        q.stride(1),
        q.stride(2),
        q.stride(3),
        v.stride(1),
        v.stride(2),
        v.stride(3),
        batch_size,
        n_heads,
        seq_len,
        d_head_qk,
        d_head_v,
        BK=BK,
        USE_INITIAL_STATE=initial_state is not None,
        STORE_FINAL_STATE=output_final_state,
        DTYPE=triton_dtype,
    )
    o = o.sum(0).to(q.dtype)
    return o, final_state


@fused_chunk_linear_attn_kernel.register_fake
def _(q, k, v, scale, initial_state, output_final_state):
    assert q.device == k.device == v.device
    assert q.is_contiguous()
    assert k.is_contiguous()
    assert v.is_contiguous()

    B, H, L, D = k.shape
    _, _, _, E = v.shape
    return q.new_empty(B, H, L, E), torch.empty(B, H, D, E, dtype=torch.float32, device=q.device)


def setup_context_bwd_pass(ctx, inputs, output):
    q, k, v, scale, initial_state, output_final_state = inputs
    if q.dtype == torch.float16:
        triton_dtype = tl.float16
    elif q.dtype == torch.bfloat16:
        triton_dtype = tl.bfloat16
    else:
        triton_dtype = tl.float32
    ctx.triton_dtype = triton_dtype
    # ctx.saved_tensors = (q, k, v, initial_state)
    ctx.q = q
    ctx.k = k
    ctx.v = v
    ctx.initial_state = initial_state
    ctx.scale = scale


def backward(ctx, grad_outputs):
    do, d_final_state = grad_outputs
    q = ctx.q
    k = ctx.k
    v = ctx.v
    # breakpoint()
    initial_state = ctx.initial_state
    # q, k, v, initial_state = ctx.saved_tensors
    batch_size, n_heads, seq_len, d_head_qk = q.shape
    d_head_v = v.shape[-1]
    scale = ctx.scale

    # Hard-code the head-dim block sizes since I need to know the number of blocks to define
    # the size of the gradient tensors.
    # ALTERNATIVE: I could initialize the tensors dq, dk, dv with a MAX_NV/MAX_NK size and
    # with zeros, and then sum them all up. This would allow us to also autotune the sizes
    # of these blocks.

    # NOTE: triton.next_power_of_2 does not work with torch custom ops library
    # BK = min(triton.next_power_of_2(d_head_qk), 64)
    # BV = min(triton.next_power_of_2(d_head_v), 64)
    # NK = triton.cdiv(d_head_qk, BK)
    # NV = triton.cdiv(d_head_v, BV)

    # Tested fixing the values to prevent the SymInt issue
    BK = 64
    BV = 64
    NK = (d_head_qk + BK - 1) // BK
    NV = (d_head_v + BV - 1) // BV

    # Initialize the output tensors
    dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
    dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
    dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
    # grid = (NV, NK, batch_size * n_heads)
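    # Cast the grid dims to plain Python ints so that no SymInts reach the Triton launcher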
    grid = (int(NV), int(NK), int(batch_size * n_heads))
    fused_chunk_linear_attn_bwd_kernel[grid](
        q,
        k,
        v,
        do,
        dq,
        dk,
        dv,
        initial_state,
        q.stride(1),
        q.stride(2),
        q.stride(3),
        v.stride(1),
        v.stride(2),
        v.stride(3),
        batch_size,
        n_heads,
        seq_len,
        d_head_qk,
        d_head_v,
        scale,
        BK=BK,
        BV=BV,
        USE_INITIAL_STATE=initial_state is not None,
        DTYPE=ctx.triton_dtype,
    )
    dk = dk.sum(0)
    dq = dq.sum(0)
    dv = dv.sum(0)

    return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None


fused_chunk_linear_attn_kernel.register_autograd(backward, setup_context=setup_context_bwd_pass)
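
For reference, this is roughly how I trigger the failure with opcheck (the shapes and scale below are placeholders, not my actual configuration):

# Rough sketch of the opcheck call from the traceback above (placeholder shapes)
q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
example = (q, k, v, 1.0, None, True)
torch.library.opcheck(fused_chunk_linear_attn_kernel, example)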

I don’t know what’s causing the problem and see a few related issues still open, such as this one. You could create a new GitHub issue with a minimal snippet reproducing the problem so that the code owners can take a look at it.