Cutlass kernel causes no grad in backward

import cutlass

import torch
from torch.autograd import Function

class GroupedGemm(Function):
    def forward(ctx, As, Bs):
        # Validate inputs
        assert len(As) == len(Bs), "Number of A and B matrices must match"
        for A, B in zip(As, Bs):
            assert A.size(-1) == B.size(-2), f"Incompatible dimensions for GEMM: {A.size()} and {B.size()}"

        # Save inputs for backward
        ctx.save_for_backward(*As, *Bs)

        # Prepare CUTLASS plan
        plan = cutlass.op.GroupedGemm(element=As[0].dtype, layout=cutlass.LayoutType.RowMajor)

        # Prepare Cs and Ds for CUTLASS GEMM
        Cs = [torch.zeros(A.size(0), B.size(-1), device=A.device, dtype=A.dtype) for A, B in zip(As, Bs)]
        Ds = [torch.empty_like(C) for C in Cs]

        # Run CUTLASS grouped GEMM, Bs, Cs, Ds)

        result =[d for d in Ds], dim=0)

        return result

    def backward(ctx, grad):
        Compute gradients using CUTLASS in the backward pass.
        grad = grad.contiguous()
        num_problems = len(ctx.saved_tensors) // 2
        As = ctx.saved_tensors[:num_problems]
        Bs = ctx.saved_tensors[num_problems:]

        # Prepare CUTLASS plan
        plan = cutlass.op.GroupedGemm(element=grad.dtype, layout=cutlass.LayoutType.RowMajor)

        # Compute gradient w.r.t. As
        agrad_list = []
        if ctx.needs_input_grad[0]:
            for grad_i, B in zip(torch.split(grad, [A.size(0) for A in As]), Bs):
                B_transposed = B.transpose(-2, -1)
                A_grad = torch.zeros(grad_i.size(0), B_transposed.size(1), device=grad.device, dtype=grad.dtype)
      [grad_i], [B_transposed], [A_grad], [A_grad])

        # Compute gradient w.r.t. Bs
        bgrad_list = []
        if ctx.needs_input_grad[1]:
            for A, grad_i in zip(As, torch.split(grad, [A.size(0) for A in As])):
                A_transposed = A.transpose(-2, -1)
                B_grad = torch.zeros_like(Bs[0])
      [A_transposed], [grad_i], [B_grad], [B_grad])

        # Return gradients
        agrad = agrad_list if ctx.needs_input_grad[0] else None
        bgrad = bgrad_list if ctx.needs_input_grad[1] else None
        return agrad, bgrad

def test_lora_gradient():
    A = torch.randn(10, 5, requires_grad=True, device="cuda")
    B = torch.randn(5, 2, requires_grad=True, device="cuda")
    C = GroupedGemm.apply([A], [B])
    print(C, f"C require grad is {C[0].requires_grad}")
    loss = C.sum()


    assert A.grad is not None, "Gradient not computed for A"
    assert B.grad is not None, "Gradient not computed for B"


I’m trying to build a custom grouped GEMM operator by wrapping around CUTLASS grouped gemm kernel, but my output after cutlass grouped gemm call does not have grad_fn, causing problem during backward. The above test case would yield the following error:

Traceback (most recent call last):
  File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/", line 87, in <module>
  File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/", line 79, in test_lora_gradient
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/", line 581, in backward
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/", line 347, in backward
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

What is the best way to wrap a Cutlass kernel?

The issue is that you are passing in the inputs to the custom autograd Function via lists. Custom autograd Function does not traverse through these lists to look for inputs that require grad. You need to unpack your lists and pass via raw inputs. (If you’d like to accept a variable number of inputs, you should use python varargs).

