The forward graphs captured by torch.export and aot_export_module are different

I’d like to capture the forward graph of my module using Dynamo. I tried two approaches, but they return different results.

  1. torch.export
  2. aot_export_module

Specifically, aot_export_module breaks torch.nn.functional.scaled_dot_product_attention into two torch.mm calls, whereas torch.export retains the fused kernel. Below is a minimal reproducible example.

import torch
from torch.export import export
from torch._functorch.aot_autograd import aot_export_module

class MyModule(torch.nn.Module):
    def forward(self, q, k, v):
        return (torch.sum(torch.nn.functional.scaled_dot_product_attention(q, k, v)),)

my_module = MyModule()
q, k, v = torch.randn(1024, 3, 128, device="cuda", requires_grad=True).unbind(1)
sg1 = export(my_module, args=(q, k, v))
joint_gm, joint_gs = aot_export_module(
    my_module, (q, k, v), trace_joint=True, output_loss_index=0, decompositions=None)
print(sg1)
print(joint_gm)

I thought both approaches capture the forward graph using Dynamo, so I’m a bit confused about why they return different results.
Also, when using torch.compile the fused attention is retained, so if SDPA is broken into parts here, when will the two torch.mm calls get fused back together?

The high-level answer:

torch.export.export() captures a “higher level” graph, with a few properties:

(1) No decompositions have run yet (so we have not, e.g., decomposed scaled_dot_product_attention into smaller operations).
(2) The output of torch.export.export is a forward-only graph that is safe to serialize, load up later in Python, and train on eagerly. It doesn’t capture the backward graph ahead of time; instead, you can re-use the eager autograd machinery to train on the generated forward graph when you load your exported program later (see the sketch below).
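To make property (2) concrete, here is a minimal sketch of the save / load / train-eagerly flow. The file name and the fresh leaf inputs are just illustrative, and the exact APIs can differ slightly across PyTorch versions:

# Minimal sketch: serialize the forward-only exported program, load it back,
# and train on it eagerly. File name and fresh leaf inputs are illustrative.
import torch
from torch.export import save, load

save(sg1, "sdpa_forward.pt2")
ep = load("sdpa_forward.pt2")
loaded = ep.module()  # a callable nn.Module that runs the captured forward graph

# Fresh leaf tensors so .grad gets populated after backward()
q2, k2, v2 = (torch.randn(1024, 128, device="cuda", requires_grad=True) for _ in range(3))
(loss,) = loaded(q2, k2, v2)
loss.backward()       # eager autograd builds the backward graph on the fly
print(q2.grad.shape)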

aot_export_module (with the trace_joint=True flag you passed above) will trace through the autograd engine, giving you a backward graph ahead of time.
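If you want to see what that joint graph looks like for your snippet, a quick sketch using standard torch.fx inspection helpers (the output differs across PyTorch versions and backends):

# With trace_joint=True the returned GraphModule holds forward and backward ops
# in a single graph; torch.fx utilities let you inspect it.
# print_tabular() requires the `tabulate` package.
joint_gm.print_readable()
joint_gm.graph.print_tabular()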

One consequence of tracing through autograd: there are a number of operators that we do not have derivative formulas for. Instead, autograd handles them by first decomposing them and then running the derivative rule for each decomposed operator. So if you want to get a full forward + backward graph ahead of time, we need to decompose any operators that we do not have derivative rules for.
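For completeness: the decompositions argument in your aot_export_module call (which you left as None) lets you decompose additional operators explicitly, on top of the ones autograd is forced to decompose. A hedged sketch, using the core ATen decomposition table from the private torch._decomp module (which may change between releases):

# Sketch: pass an explicit decomposition table so the joint graph is lowered
# further, beyond what autograd alone forces. torch._decomp is a private module.
from torch._decomp import core_aten_decompositions
from torch._functorch.aot_autograd import aot_export_module

decomposed_gm, decomposed_gs = aot_export_module(
    my_module,
    (q, k, v),
    trace_joint=True,
    output_loss_index=0,
    decompositions=core_aten_decompositions(),  # decompose to the core ATen opset
)
print(decomposed_gm)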

aten.scaled_dot_product_attention is one of the operators we don’t have a single generic derivative formula for. In principle we could write one for SDPA (cc @drisspg), but instead we have several backend implementations of SDPA, each with its own derivative rule.

In your example, it looks like for your particular input shapes we aren’t able to use one of the dedicated kernel backends for SDPA, so we are forced to fall back to the generic “math” backend, which decomposes SDPA into several smaller kernels like mm and softmax.
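One way to check whether a fused SDPA backend can handle a given set of inputs is to restrict SDPA to that backend and see whether the call goes through. A sketch using the torch.nn.attention API from recent PyTorch releases (older versions exposed a similar torch.backends.cuda.sdp_kernel context manager):

# Sketch: probe whether the memory-efficient SDPA backend accepts these inputs.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

def can_use_efficient(q, k, v):
    try:
        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return True
    except RuntimeError:
        # SDPA raises when none of the allowed backends can run these inputs
        return False

print(can_use_efficient(q, k, v))      # 2-D q/k/v from the original example
q4, k4, v4 = torch.randn(1, 1, 1024, 3, 128, device="cuda").unbind(-2)
print(can_use_efficient(q4, k4, v4))   # 4-D inputs, as in the tweak below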

If I tweak your inputs to be:

# make q/k/v 4 dimensional
q, k, v = torch.randn(1, 1, 1024, 3, 128, device="cuda", requires_grad=True).unbind(-2)

then we are able to use one of the dedicated kernels, and I get out a different graph:

def forward(self, arg0_1, arg1_1, arg2_1):
    _scaled_dot_product_efficient_attention = torch.ops.aten._scaled_dot_product_efficient_attention.default(arg0_1, arg1_1, arg2_1, None, True)
    getitem = _scaled_dot_product_efficient_attention[0]
    getitem_1 = _scaled_dot_product_efficient_attention[1]
    getitem_2 = _scaled_dot_product_efficient_attention[2]
    getitem_3 = _scaled_dot_product_efficient_attention[3];  _scaled_dot_product_efficient_attention = None
    detach = torch.ops.aten.detach.default(getitem);  detach = None
    detach_1 = torch.ops.aten.detach.default(getitem)
    detach_2 = torch.ops.aten.detach.default(detach_1);  detach_1 = None
    detach_3 = torch.ops.aten.detach.default(detach_2);  detach_2 = None
    detach_4 = torch.ops.aten.detach.default(detach_3);  detach_3 = None
    sum_1 = torch.ops.aten.sum.default(getitem);  getitem = None
    ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
    expand = torch.ops.aten.expand.default(ones_like, [1, 1, 1024, 128]);  ones_like = None
    detach_5 = torch.ops.aten.detach.default(detach_4);  detach_4 = None
    detach_6 = torch.ops.aten.detach.default(detach_5);  detach_5 = None
    detach_7 = torch.ops.aten.detach.default(detach_6);  detach_6 = None
    detach_8 = torch.ops.aten.detach.default(detach_7);  detach_7 = None
    _scaled_dot_product_efficient_attention_backward = torch.ops.aten._scaled_dot_product_efficient_attention_backward.default(expand, arg0_1, arg1_1, arg2_1, None, detach_8, getitem_1, getitem_2, getitem_3, 0.0, [True, True, True, False]);  expand = arg0_1 = arg1_1 = arg2_1 = detach_8 = getitem_1 = getitem_2 = getitem_3 = None
    getitem_4 = _scaled_dot_product_efficient_attention_backward[0]
    getitem_5 = _scaled_dot_product_efficient_attention_backward[1]
    getitem_6 = _scaled_dot_product_efficient_attention_backward[2]
    getitem_7 = _scaled_dot_product_efficient_attention_backward[3];  _scaled_dot_product_efficient_attention_backward = getitem_7 = None
    return (sum_1, getitem_4, getitem_5, getitem_6)

Thanks for your detailed explanation.