How to make TorchDynamo decompose torch.nn.functional.scaled_dot_product_attention?

Originally, I used torch 2.4 to export a Llama model with torch.export(), and torch.nn.functional.scaled_dot_product_attention was decomposed as shown below:

transformers/models/llama/modeling_llama.py:603 in forward, code: attn_output = torch.nn.functional.scaled_dot_product_attention(

    unsqueeze_8: "f32[1, 1, s2 + 2]" = torch.ops.aten.unsqueeze.default(mul, 0);  mul = None
    unsqueeze_9: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.unsqueeze.default(unsqueeze_8, 1);  unsqueeze_8 = None
    slice_16: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807);  unsqueeze_9 = None
    slice_17: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_16, 3, 0, 9223372036854775807);  slice_16 = None
    expand_5: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.expand.default(slice_17, [sym_size_int_1, 1, -1, -1]);  slice_17 = None
    slice_18: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(expand_5, 0, 0, 9223372036854775807);  expand_5 = None
    slice_19: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_18, 1, 0, 9223372036854775807);  slice_18 = None
    slice_20: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_19, 2, 0, 9223372036854775807);  slice_19 = None
    slice_21: "f32[s0, 1, 1, s2 + 1]" = torch.ops.aten.slice.Tensor(slice_20, 3, 0, sym_size_int_3);  slice_20 = None
    mul_9: "f32[s0, 32, 1, 128]" = torch.ops.aten.mul.Scalar(add_4, 0.29730177875068026);  add_4 = None
    transpose_7: "f32[s0, 32, 128, s2 + 1]" = torch.ops.aten.transpose.int(_unsafe_view, -2, -1);  _unsafe_view = None
    mul_10: "f32[s0, 32, 128, s2 + 1]" = torch.ops.aten.mul.Scalar(transpose_7, 0.29730177875068026);  transpose_7 = None
    expand_6: "f32[s0, 32, 1, 128]" = torch.ops.aten.expand.default(mul_9, [sym_size_int_1, 32, 1, 128]);  mul_9 = None
    mul_11: "Sym(32*s0)" = sym_size_int_1 * 32
    view_13: "f32[32*s0, 1, 128]" = torch.ops.aten.view.default(expand_6, [mul_11, 1, 128]);  expand_6 = None
    expand_7: "f32[s0, 32, 128, s2 + 1]" = torch.ops.aten.expand.default(mul_10, [sym_size_int_1, 32, 128, sym_size_int_3]);  mul_10 = None
    view_14: "f32[32*s0, 128, s2 + 1]" = torch.ops.aten.view.default(expand_7, [mul_11, 128, sym_size_int_3]);  expand_7 = mul_11 = None
    bmm_1: "f32[32*s0, 1, s2 + 1]" = torch.ops.aten.bmm.default(view_13, view_14);  view_13 = view_14 = None
    view_15: "f32[s0, 32, 1, s2 + 1]" = torch.ops.aten.view.default(bmm_1, [sym_size_int_1, 32, 1, sym_size_int_3]);  bmm_1 = None
    add_6: "f32[s0, 32, 1, s2 + 1]" = torch.ops.aten.add.Tensor(view_15, slice_21);  view_15 = slice_21 = None
    _softmax: "f32[s0, 32, 1, s2 + 1]" = torch.ops.aten._softmax.default(add_6, -1, False);  add_6 = None
    expand_8: "f32[s0, 32, 1, s2 + 1]" = torch.ops.aten.expand.default(_softmax, [sym_size_int_1, 32, 1, sym_size_int_3]);  _softmax = None
    mul_12: "Sym(32*s0)" = sym_size_int_1 * 32
    view_16: "f32[32*s0, 1, s2 + 1]" = torch.ops.aten.view.default(expand_8, [mul_12, 1, sym_size_int_3]);  expand_8 = sym_size_int_3 = None
    expand_9: "f32[s0, 32, s2 + 1, 128]" = torch.ops.aten.expand.default(_unsafe_view_1, [sym_size_int_1, 32, sym_size_int_5, 128]);  _unsafe_view_1 = None
    view_17: "f32[32*s0, s2 + 1, 128]" = torch.ops.aten.view.default(expand_9, [mul_12, sym_size_int_5, 128]);  expand_9 = mul_12 = sym_size_int_5 = None
    bmm_2: "f32[32*s0, 1, 128]" = torch.ops.aten.bmm.default(view_16, view_17);  view_16 = view_17 = None
    view_18: "f32[s0, 32, 1, 128]" = torch.ops.aten.view.default(bmm_2, [sym_size_int_1, 32, 1, 128]);  bmm_2 = None
    transpose_8: "f32[s0, 1, 32, 128]" = torch.ops.aten.transpose.int(view_18, 1, 2);  view_18 = None
    transpose_9: "f32[s0, 32, 1, 128]" = torch.ops.aten.transpose.int(transpose_8, 1, 2);  transpose_8 = None

But now with torch 2.6/2.7, torch.nn.functional.scaled_dot_product_attention is no longer decomposed, as shown below:

    unsqueeze_10: "f32[1, 1, s2 + 2]" = torch.ops.aten.unsqueeze.default(mul_9, 0);  mul_9 = None
    unsqueeze_11: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.unsqueeze.default(unsqueeze_10, 1);  unsqueeze_10 = None
    slice_22: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 2, 0, 9223372036854775807);  unsqueeze_11 = None
    slice_23: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_22, 3, 0, 9223372036854775807);  slice_22 = None
    expand_4: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.expand.default(slice_23, [sym_size_int_14, 1, -1, -1]);  slice_23 = None
    slice_24: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(expand_4, 0, 0, 9223372036854775807);  expand_4 = None
    slice_25: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_24, 1, 0, 9223372036854775807);  slice_24 = None
    slice_26: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_25, 2, 0, 9223372036854775807);  slice_25 = None
    slice_27: "f32[s0, 1, 1, s2 + 1]" = torch.ops.aten.slice.Tensor(slice_26, 3, 0, add_4);  slice_26 = add_4 = None
    scaled_dot_product_attention: "f32[s0, 32, 1, 128]" = torch.ops.aten.scaled_dot_product_attention.default(add_133, _unsafe_view, _unsafe_view_1, slice_27);  add_133 = _unsafe_view = _unsafe_view_1 = slice_27 = None

Does anybody know the reason for this, and is there a switch to toggle between these two behaviors? Thanks a lot!
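
For reference, here is a minimal repro sketch of the kind of export that shows the behavior; the toy module, tensor shapes, and mask below are illustrative stand-ins rather than the original Llama setup:

    import torch
    import torch.nn.functional as F

    # Toy stand-in for the Llama attention call; shapes and mask are illustrative.
    class ToyAttention(torch.nn.Module):
        def forward(self, q, k, v, mask):
            return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    q = torch.randn(1, 32, 1, 128)
    k = torch.randn(1, 32, 9, 128)
    v = torch.randn(1, 32, 9, 128)
    mask = torch.zeros(1, 1, 1, 9)

    ep = torch.export.export(ToyAttention(), (q, k, v, mask))
    # On torch 2.4 the printed graph shows mul/bmm/_softmax; on 2.6/2.7 it keeps
    # torch.ops.aten.scaled_dot_product_attention.default.
    print(ep)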

Hi @Weiliang_Lin. I think since PyTorch 2.6 we've switched to a different level of IR in the default torch.export(). To decompose the sdpa op, please call the separate method .run_decompositions() on the ExportedProgram you got.
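
A minimal sketch of what that looks like, assuming `ep` is the ExportedProgram returned by torch.export.export() (as in the repro sketch above):

    # run_decompositions() returns a new ExportedProgram; with no arguments it
    # applies the default (core ATen) decomposition table, which should break
    # aten.scaled_dot_product_attention back into the explicit mul/bmm/_softmax
    # sequence seen in the torch 2.4 output.
    decomposed_ep = ep.run_decompositions()
    print(decomposed_ep)

If you need finer control over which ops get decomposed, run_decompositions() also accepts an explicit decomposition table as an argument.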