import cutlass
import torch
from torch.autograd import Function
class GroupedGemm(Function):
    @staticmethod
    def forward(ctx, As, Bs):
        # Validate inputs
        assert len(As) == len(Bs), "Number of A and B matrices must match"
        for A, B in zip(As, Bs):
            assert A.size(-1) == B.size(-2), f"Incompatible dimensions for GEMM: {A.size()} and {B.size()}"
        # Save inputs for backward
        ctx.save_for_backward(*As, *Bs)
        # Prepare CUTLASS plan
        plan = cutlass.op.GroupedGemm(element=As[0].dtype, layout=cutlass.LayoutType.RowMajor)
        # Prepare Cs and Ds for CUTLASS GEMM
        Cs = [torch.zeros(A.size(0), B.size(-1), device=A.device, dtype=A.dtype) for A, B in zip(As, Bs)]
        Ds = [torch.empty_like(C) for C in Cs]
        # Run CUTLASS grouped GEMM
        plan.run(As, Bs, Cs, Ds)
        return torch.cat(Ds, dim=0)
    @staticmethod
    def backward(ctx, grad):
        """
        Compute gradients using CUTLASS plan.run in the backward pass:
        dA_i = dD_i @ B_i^T and dB_i = A_i^T @ dD_i.
        """
        grad = grad.contiguous()
        num_problems = len(ctx.saved_tensors) // 2
        As = ctx.saved_tensors[:num_problems]
        Bs = ctx.saved_tensors[num_problems:]
        grads = torch.split(grad, [A.size(0) for A in As])
        # Prepare CUTLASS plan
        plan = cutlass.op.GroupedGemm(element=grad.dtype, layout=cutlass.LayoutType.RowMajor)
        # Compute gradient w.r.t. As: dA_i = dD_i @ B_i^T
        agrad_list = []
        if ctx.needs_input_grad[0]:
            for grad_i, B in zip(grads, Bs):
                # Materialize the transpose so the operand is row-major
                # contiguous, matching the layout the plan was built with
                B_transposed = B.transpose(-2, -1).contiguous()
                A_grad = torch.zeros(grad_i.size(0), B_transposed.size(-1), device=grad.device, dtype=grad.dtype)
                plan.run([grad_i], [B_transposed], [A_grad], [A_grad])
                agrad_list.append(A_grad)
        # Compute gradient w.r.t. Bs: dB_i = A_i^T @ dD_i
        bgrad_list = []
        if ctx.needs_input_grad[1]:
            for A, B, grad_i in zip(As, Bs, grads):
                A_transposed = A.transpose(-2, -1).contiguous()
                # Size the gradient per problem; zeros_like(Bs[0]) is only
                # correct when every B has the same shape
                B_grad = torch.zeros_like(B)
                plan.run([A_transposed], [grad_i], [B_grad], [B_grad])
                bgrad_list.append(B_grad)
        # Return gradients
        agrad = agrad_list if ctx.needs_input_grad[0] else None
        bgrad = bgrad_list if ctx.needs_input_grad[1] else None
        return agrad, bgrad
def test_lora_gradient():
    A = torch.randn(10, 5, requires_grad=True, device="cuda")
    B = torch.randn(5, 2, requires_grad=True, device="cuda")
    C = GroupedGemm.apply([A], [B])
    print(C, f"C.requires_grad is {C.requires_grad}")
    loss = C.sum()
    loss.backward()
    print(A.grad)
    print(B.grad)
    assert A.grad is not None, "Gradient not computed for A"
    assert B.grad is not None, "Gradient not computed for B"

test_lora_gradient()
I'm trying to build a custom grouped GEMM operator by wrapping the CUTLASS grouped GEMM kernel, but the output of the CUTLASS call has no grad_fn, which breaks the backward pass. The test case above yields the following error:
Traceback (most recent call last):
  File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/test_grouped_gemm_backward.py", line 87, in <module>
    test_lora_gradient()
  File "/home/ubuntu/torchtune/tests/torchtune/modules/peft/test_grouped_gemm_backward.py", line 79, in test_lora_gradient
    loss.backward()
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/ubuntu/.conda/envs/torchtune/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
What is the best way to wrap a CUTLASS kernel so that its output participates in autograd? For reference, a CUTLASS-free repro of the same symptom is below.
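
The symptom does not appear to be CUTLASS-specific. In the sketch below (a hypothetical ListMatmul stand-in, with plain torch.matmul in place of plan.run), the tensors are passed to apply inside Python lists exactly as above, and the output likewise comes back with requires_grad=False and no grad_fn:

import torch
from torch.autograd import Function

class ListMatmul(Function):
    """Hypothetical stand-in for GroupedGemm: same list-of-tensors calling
    convention, but plain torch.matmul instead of the CUTLASS plan."""

    @staticmethod
    def forward(ctx, As, Bs):
        ctx.save_for_backward(*As, *Bs)
        return torch.cat([A @ B for A, B in zip(As, Bs)], dim=0)

    @staticmethod
    def backward(ctx, grad):
        # Never reached in this repro: the forward output has no grad_fn
        return None, None

A = torch.randn(10, 5, requires_grad=True)
B = torch.randn(5, 2, requires_grad=True)
C = ListMatmul.apply([A], [B])
print(C.requires_grad, C.grad_fn)  # prints: False None

So the problem seems to lie in how the inputs reach autograd.Function.apply rather than in the kernel call itself.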