Hi, I am trying to implement a somewhat optimal autograd function for a fused layer norm + linear layer. I can, in theory, write it manually using the provided instructions…
But here is the gotcha: that would be wildly inefficient, since the fused layer norm kernel is much faster than a manual implementation, and the same goes for the backward pass. A fused layer norm + linear kernel would be even more efficient, but seemingly I can't use that either.
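To give a sense of what I mean by "wildly inefficient", this is the kind of rough (admittedly unscientific) comparison I'm basing that on; the exact numbers will of course depend on shapes and GPU:

import torch
import torch.nn.functional as F

def manual_layer_norm(x, eps=1e-5):
    # Naive implementation: launches several separate elementwise kernels.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

x = torch.randn(4096, 1024, device="cuda")

def time_fn(fn, iters=100):
    fn()  # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

print("fused :", time_fn(lambda: F.layer_norm(x, x.shape[-1:])))
print("manual:", time_fn(lambda: manual_layer_norm(x)))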
My dream implementation would look something like this (very pseudocode-ish):
def forward(ctx, x, W, b):
    ctx.save_for_backward(x)  # store only the input for bwd
    y = F.layer_norm(x, x.shape[-1:], None, None)
    out = F.linear(y, W, b)
    # would be perfect to use a fused kernel for the previous two lines if possible
    return out

def backward(ctx, out_grad):
    x = ...  # retrieved from ctx
    y = F.layer_norm(x, x.shape[-1:], None, None)
    y_grad, W_grad, b_grad = F.linear.backward(y, W, b, out_grad)  # or something similar
    x_grad, _, _ = F.layer_norm.backward(x, None, None, y_grad)  # or something similar
    # return gradients
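The closest workaround I can think of is calling the ATen ops directly via torch.ops.aten.native_layer_norm / native_layer_norm_backward, but I'm not sure that's an intended or stable way to use them. A rough sketch of what I mean (no layer-norm affine params, normalization recomputed in backward, input assumed to have a batch dim):

import torch
import torch.nn.functional as F

class LayerNormLinear(torch.autograd.Function):
    # Hypothetical layer_norm + linear with recomputation in backward.

    @staticmethod
    def forward(ctx, x, W, b):
        eps = 1e-5
        # F.layer_norm dispatches to the fused native_layer_norm kernel.
        y = F.layer_norm(x, x.shape[-1:], None, None, eps)
        ctx.save_for_backward(x, W)
        ctx.has_bias = b is not None
        ctx.eps = eps
        return F.linear(y, W, b)

    @staticmethod
    def backward(ctx, out_grad):
        x, W = ctx.saved_tensors
        # Recompute the normalization with the same fused kernel instead of
        # having saved y / mean / rstd in forward.
        y, mean, rstd = torch.ops.aten.native_layer_norm(
            x, [x.shape[-1]], None, None, ctx.eps)
        # Backward of out = y @ W.T + b, done by hand.
        y_grad = out_grad @ W
        W_grad = out_grad.flatten(0, -2).T @ y.flatten(0, -2)
        b_grad = out_grad.flatten(0, -2).sum(0) if ctx.has_bias else None
        # Fused layer norm backward; output_mask requests only grad wrt input.
        x_grad = torch.ops.aten.native_layer_norm_backward(
            y_grad, x, [x.shape[-1]], mean, rstd, None, None,
            [True, False, False])[0]
        return x_grad, W_grad, b_grad

Usage would be LayerNormLinear.apply(x, W, b). What I don't like is that I'm reaching into torch.ops.aten, and I still don't get a single fused layer norm + linear kernel.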
So my point is: I want to use the same fused kernels that torch uses internally, but I can't find a way to call them directly.
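For reference, this is how I'd sanity-check whatever I end up with against the plain unfused path (using the hypothetical LayerNormLinear from the sketch above):

import torch
import torch.nn.functional as F

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
W = torch.randn(4096, 1024, device="cuda", requires_grad=True)
b = torch.randn(4096, device="cuda", requires_grad=True)

# Reference: plain unfused composition, backward handled by autograd.
ref = F.linear(F.layer_norm(x, x.shape[-1:]), W, b)
ref.sum().backward()
ref_grads = [t.grad.clone() for t in (x, W, b)]
for t in (x, W, b):
    t.grad = None

# Candidate implementation (hypothetical LayerNormLinear from above).
out = LayerNormLinear.apply(x, W, b)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
out.sum().backward()
for g_ref, t in zip(ref_grads, (x, W, b)):
    torch.testing.assert_close(t.grad, g_ref, rtol=1e-3, atol=1e-3)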