Efficient custom backward in 2024

Hi, I am trying to implement a somewhat optimal autograd function for a fused layer norm + linear layer. I can in theory write it manually using the provided instructions, but …

But here is the gotcha: that would be wildly inefficient, because the fused layer norm kernel is much faster than a manual implementation, and the same holds for the backward pass. A fused layernorm + linear would be even more efficient, but seemingly I can’t use that either.
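For context, by "manual" I mean decomposing the math into elementary ops, roughly like the illustration below (not my real code, just to show what the unfused version looks like):

import torch
import torch.nn.functional as F

def manual_layer_norm_linear(x, W, b, eps=1e-5):
    # unfused layer norm: each line launches its own reduction/elementwise kernel
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    y = (x - mean) / torch.sqrt(var + eps)
    # followed by the linear projection
    return F.linear(y, W, b)

Writing the backward for this by hand is the same story: a chain of separate ops instead of the single fused kernel that F.layer_norm dispatches to.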

My dream implementation would look something like this (very pseudocode-ish):

def forward(ctx, x, W, b):
    ctx.save_for_backward(x, W, b)  # store only the inputs for bwd, not the intermediate y
    y = F.layer_norm(x, x.shape[-1:])  # no affine params
    out = F.linear(y, W, b)
    # would be perfect to use a fused kernel for the previous two lines, if possible
    return out

def backward(ctx, out_grad):
    x, W, b = ctx.saved_tensors
    y = F.layer_norm(x, x.shape[-1:])  # recompute y instead of storing it
    y_grad, W_grad, b_grad = F.linear.backward(y, W, b, out_grad)  # or something similar -- no such API exists
    x_grad, _, _ = F.layer_norm.backward(x, None, None, y_grad)  # or something similar -- no such API either
    # return gradients
    return x_grad, W_grad, b_grad
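The closest thing I could find to the internal kernels are the aten-level ops, so here is a rough sketch of what I mean in (hopefully) runnable form. To be clear, this is my guess: I do not know whether calling torch.ops.aten.native_layer_norm / native_layer_norm_backward from user code like this is supported, or whether they are really the fused kernels autograd uses.

import torch
import torch.nn.functional as F

class LayerNormLinear(torch.autograd.Function):
    # hypothetical sketch: I am guessing these aten ops are the fused kernels in question

    @staticmethod
    def forward(ctx, x, W, b, eps=1e-5):
        normalized_shape = x.shape[-1:]
        # aten::native_layer_norm returns (output, mean, rstd)
        y, mean, rstd = torch.ops.aten.native_layer_norm(x, normalized_shape, None, None, eps)
        ctx.save_for_backward(x, W, b, mean, rstd)
        ctx.eps = eps
        return F.linear(y, W, b)

    @staticmethod
    def backward(ctx, out_grad):
        x, W, b, mean, rstd = ctx.saved_tensors
        normalized_shape = x.shape[-1:]
        # recompute y with the same fused forward kernel instead of saving it
        y, _, _ = torch.ops.aten.native_layer_norm(x, normalized_shape, None, None, ctx.eps)
        # the linear backward is just matmuls, so doing it by hand seems fine
        y_grad = out_grad @ W
        W_grad = out_grad.flatten(0, -2).t() @ y.flatten(0, -2)
        b_grad = out_grad.flatten(0, -2).sum(dim=0)
        # aten::native_layer_norm_backward returns (grad_input, grad_weight, grad_bias)
        x_grad, _, _ = torch.ops.aten.native_layer_norm_backward(
            y_grad, x, normalized_shape, mean, rstd, None, None, [True, False, False]
        )
        return x_grad, W_grad, b_grad, None

The appeal is that only x, W, b and the small mean/rstd tensors are saved, and the backward reuses the fused layer norm kernels. But I have no idea whether going through torch.ops.aten like this is considered acceptable, or whether these ops are a stable interface.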

So my point is: I want to use the same fused kernels that torch uses internally, but I can’t find a documented, supported way to call them directly.
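If something along those lines is possible, I would at least want to sanity check it against the plain composed ops, roughly like this (using the LayerNormLinear sketch from above):

x = torch.randn(4, 16, 64, dtype=torch.double, requires_grad=True)
W = torch.randn(32, 64, dtype=torch.double, requires_grad=True)
b = torch.randn(32, dtype=torch.double, requires_grad=True)

ref = F.linear(F.layer_norm(x, x.shape[-1:]), W, b)
out = LayerNormLinear.apply(x, W, b)
assert torch.allclose(out, ref)

# double precision so gradcheck can compare the hand-written backward against numerical gradients
assert torch.autograd.gradcheck(LayerNormLinear.apply, (x, W, b))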