import cutlass
import torch
from torch.autograd import Function
class GroupedGemm(Function):
def forward(ctx, As, Bs):
# Validate inputs
assert len(As) == len(Bs), "Number of A and B matrices must match"
for A, B in zip(As, Bs):
assert A.size(-1) == B.size(-2), f"Incompatible dimensions for GEMM: {A.size()} and {B.size()}"
# Save inputs for backward
ctx.save_for_backward(*As, *Bs)
# Prepare CUTLASS plan
plan = cutlass.op.GroupedGemm(element=As[0].dtype, layout=cutlass.LayoutType.RowMajor)
# Prepare Cs and Ds for CUTLASS GEMM
Cs = [torch.zeros(A.size(0), B.size(-1), device=A.device, dtype=A.dtype) for A, B in zip(As, Bs)]
Ds = [torch.empty_like(C) for C in Cs]
# Run CUTLASS grouped GEMM, Bs, Cs, Ds)
result =[d for d in Ds], dim=0)
return result
def backward(ctx, grad):
Compute gradients using CUTLASS in the backward pass.
grad = grad.contiguous()
num_problems = len(ctx.saved_tensors) // 2
As = ctx.saved_tensors[:num_problems]
Bs = ctx.saved_tensors[num_problems:]
# Prepare CUTLASS plan
plan = cutlass.op.GroupedGemm(element=grad.dtype, layout=cutlass.LayoutType.RowMajor)
# Compute gradient w.r.t. As
agrad_list = []
if ctx.needs_input_grad[0]:
for grad_i, B in zip(torch.split(grad, [A.size(0) for A in As]), Bs):
B_transposed = B.transpose(-2, -1)
A_grad = torch.zeros(grad_i.size(0), B_transposed.size(1), device=grad.device, dtype=grad.dtype)[grad_i], [B_transposed], [A_grad], [A_grad])
# Compute gradient w.r.t. Bs
bgrad_list = []
if ctx.needs_input_grad[1]:
for A, grad_i in zip(As, torch.split(grad, [A.size(0) for A in As])):
A_transposed = A.transpose(-2, -1)
B_grad = torch.zeros_like(Bs[0])[A_transposed], [grad_i], [B_grad], [B_grad])
# Return gradients
agrad = agrad_list if ctx.needs_input_grad[0] else None
bgrad = bgrad_list if ctx.needs_input_grad[1] else None
return agrad, bgrad
def test_lora_gradient():
A = torch.randn(10, 5, requires_grad=True, device="cuda")
B = torch.randn(5, 2, requires_grad=True, device="cuda")
C = GroupedGemm.apply([A], [B])
print(C, f"C require grad is {C[0].requires_grad}")
loss = C.sum()
assert A.grad is not None, "Gradient not computed for A"
assert B.grad is not None, "Gradient not computed for B"
I’m trying to build a custom grouped GEMM operator by wrapping around CUTLASS grouped gemm kernel, but my output after cutlass grouped gemm call does not have grad_fn, causing problem during backward. The above test case would yield the following error:
Traceback (most recent call last):
File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/", line 87, in <module>
File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/", line 79, in test_lora_gradient
File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/", line 581, in backward
File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/", line 347, in backward
File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
What is the best way to wrap a Cutlass kernel?