The high level answer is that:

torch.export.export() captures a “higher level” graph, with a few properties:
(1) No decompositions have run yet (so we have not, e.g., decomposed scaled_dot_product_attention into smaller operations).
(2) The output of torch.export.export is a forward-only graph that is safe to serialize, load up later in Python, and train on eagerly. It doesn’t actually capture the backward graph ahead of time; instead, you can re-use the eager autograd machinery to train on the generated forward graph when you load your exported program later.
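To make (2) concrete, here is a minimal sketch of that workflow (the module, file name, and shapes are made up for illustration, and the exact training ergonomics can vary a bit across releases):

import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# Capture a forward-only graph (no decompositions have run yet) and serialize it.
ep = torch.export.export(MyModel(), (torch.randn(2, 4),))
torch.export.save(ep, "model.pt2")

# Later, possibly in a different process: load it back and run it eagerly.
loaded = torch.export.load("model.pt2")
mod = loaded.module()  # an nn.Module wrapping the exported forward graph
x = torch.randn(2, 4, requires_grad=True)
out = mod(x)
out.sum().backward()  # the backward comes from the normal eager autograd machinery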
aot_export_module (if you pass in the trace_joint=True flag like you specified above) will trace through the autograd engine, giving you a backward graph ahead of time.
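By contrast, a rough sketch of the aot_export_module call (aot_export_module lives under torch._functorch.aot_autograd, a private API, so the import path and keyword arguments here are my assumptions and may shift between releases):

import torch
from torch._functorch.aot_autograd import aot_export_module

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # Return a scalar "loss" so the joint trace has something to differentiate.
        return (self.linear(x).sum(),)

# trace_joint=True asks for a single graph containing both forward and backward;
# output_loss_index says which output is the loss to differentiate through.
gm, signature = aot_export_module(MyModel(), (torch.randn(2, 4),), trace_joint=True, output_loss_index=0)
gm.print_readable()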
One consequence is that there are a number of operators that we do not have derivative formulas for. Instead, autograd handles them by first decomposing them and then running the derivative rule for each decomposed operator. So if you want to fully get out a forward + backward graph ahead of time, we need to decompose any operators that we do not have derivative rules for.
aten.scaled_dot_product_attention is one of these operators. In principle we could write a generic derivative rule for SDPA (cc @drisspg), but instead we have various backend implementations for SDPA that each have their own derivative rule.
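As an aside, one way to check whether a dedicated kernel backend can handle your inputs is to disallow the math fallback and see whether the call still goes through. This is just a sketch assuming a recent PyTorch that ships torch.nn.attention (older versions expose a similar torch.backends.cuda.sdp_kernel context manager):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Substitute your own q/k/v here; these are the tweaked 4-d inputs from below.
q, k, v = torch.randn(1, 1, 1024, 3, 128, device="cuda", requires_grad=True).unbind(-2)

# Disallow the generic/math backend; if neither dedicated kernel supports these
# inputs, scaled_dot_product_attention raises instead of silently decomposing.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)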
In your example, it looks like for your particular input shapes, we aren’t able to use one of the dedicated kernel backends for SDPA (and are thus forced to use the “generic/math” backend, which decomposes SDPA into several smaller kernels like mm/softmax).
If I tweak your inputs to be:
# make q/k/v 4-dimensional: each ends up with shape (1, 1, 1024, 128), i.e. (batch, num_heads, seq_len, head_dim)
q, k, v = torch.randn(1, 1, 1024, 3, 128, device="cuda", requires_grad=True).unbind(-2)
then we are able to use one of the dedicated kernels, and I get out a different graph:
def forward(self, arg0_1, arg1_1, arg2_1):
    _scaled_dot_product_efficient_attention = torch.ops.aten._scaled_dot_product_efficient_attention.default(arg0_1, arg1_1, arg2_1, None, True)
    getitem = _scaled_dot_product_efficient_attention[0]
    getitem_1 = _scaled_dot_product_efficient_attention[1]
    getitem_2 = _scaled_dot_product_efficient_attention[2]
    getitem_3 = _scaled_dot_product_efficient_attention[3]; _scaled_dot_product_efficient_attention = None
    detach = torch.ops.aten.detach.default(getitem); detach = None
    detach_1 = torch.ops.aten.detach.default(getitem)
    detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None
    detach_3 = torch.ops.aten.detach.default(detach_2); detach_2 = None
    detach_4 = torch.ops.aten.detach.default(detach_3); detach_3 = None
    sum_1 = torch.ops.aten.sum.default(getitem); getitem = None
    ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
    expand = torch.ops.aten.expand.default(ones_like, [1, 1, 1024, 128]); ones_like = None
    detach_5 = torch.ops.aten.detach.default(detach_4); detach_4 = None
    detach_6 = torch.ops.aten.detach.default(detach_5); detach_5 = None
    detach_7 = torch.ops.aten.detach.default(detach_6); detach_6 = None
    detach_8 = torch.ops.aten.detach.default(detach_7); detach_7 = None
    _scaled_dot_product_efficient_attention_backward = torch.ops.aten._scaled_dot_product_efficient_attention_backward.default(expand, arg0_1, arg1_1, arg2_1, None, detach_8, getitem_1, getitem_2, getitem_3, 0.0, [True, True, True, False]); expand = arg0_1 = arg1_1 = arg2_1 = detach_8 = getitem_1 = getitem_2 = getitem_3 = None
    getitem_4 = _scaled_dot_product_efficient_attention_backward[0]
    getitem_5 = _scaled_dot_product_efficient_attention_backward[1]
    getitem_6 = _scaled_dot_product_efficient_attention_backward[2]
    getitem_7 = _scaled_dot_product_efficient_attention_backward[3]; _scaled_dot_product_efficient_attention_backward = getitem_7 = None
    return (sum_1, getitem_4, getitem_5, getitem_6)
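For completeness, the graph above comes from the same kind of aot_export_module call as in the sketch earlier, just with the tweaked 4-dimensional q/k/v (again, treat the exact signature of the private API and the wrapper module as assumptions on my part):

import torch
import torch.nn.functional as F
from torch._functorch.aot_autograd import aot_export_module

class Attention(torch.nn.Module):
    def forward(self, q, k, v):
        # Sum to a scalar so the joint trace has a loss to differentiate.
        return (F.scaled_dot_product_attention(q, k, v).sum(),)

q, k, v = torch.randn(1, 1, 1024, 3, 128, device="cuda", requires_grad=True).unbind(-2)
gm, signature = aot_export_module(Attention(), (q, k, v), trace_joint=True, output_loss_index=0)
gm.print_readable()

Notice that the joint graph now calls the dedicated _scaled_dot_product_efficient_attention and _scaled_dot_product_efficient_attention_backward kernels directly, rather than a pile of mm/softmax ops.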