Originally, I used torch 2.4 to export a Llama model with torch.export(), and the exported graph contained a decomposed version of torch.nn.functional.scaled_dot_product_attention, like below:
transformers/models/llama/modeling_llama.py:603 in forward, code: attn_output = torch.nn.functional.scaled_dot_product_attention(
unsqueeze_8: "f32[1, 1, s2 + 2]" = torch.ops.aten.unsqueeze.default(mul, 0); mul = None
unsqueeze_9: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.unsqueeze.default(unsqueeze_8, 1); unsqueeze_8 = None
slice_16: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 2, 0, 9223372036854775807); unsqueeze_9 = None
slice_17: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_16, 3, 0, 9223372036854775807); slice_16 = None
expand_5: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.expand.default(slice_17, [sym_size_int_1, 1, -1, -1]); slice_17 = None
slice_18: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(expand_5, 0, 0, 9223372036854775807); expand_5 = None
slice_19: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_18, 1, 0, 9223372036854775807); slice_18 = None
slice_20: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_19, 2, 0, 9223372036854775807); slice_19 = None
slice_21: "f32[s0, 1, 1, s2 + 1]" = torch.ops.aten.slice.Tensor(slice_20, 3, 0, sym_size_int_3); slice_20 = None
mul_9: "f32[s0, 32, 1, 128]" = torch.ops.aten.mul.Scalar(add_4, 0.29730177875068026); add_4 = None
transpose_7: "f32[s0, 32, 128, s2 + 1]" = torch.ops.aten.transpose.int(_unsafe_view, -2, -1); _unsafe_view = None
mul_10: "f32[s0, 32, 128, s2 + 1]" = torch.ops.aten.mul.Scalar(transpose_7, 0.29730177875068026); transpose_7 = None
expand_6: "f32[s0, 32, 1, 128]" = torch.ops.aten.expand.default(mul_9, [sym_size_int_1, 32, 1, 128]); mul_9 = None
mul_11: "Sym(32*s0)" = sym_size_int_1 * 32
view_13: "f32[32*s0, 1, 128]" = torch.ops.aten.view.default(expand_6, [mul_11, 1, 128]); expand_6 = None
expand_7: "f32[s0, 32, 128, s2 + 1]" = torch.ops.aten.expand.default(mul_10, [sym_size_int_1, 32, 128, sym_size_int_3]); mul_10 = None
view_14: "f32[32*s0, 128, s2 + 1]" = torch.ops.aten.view.default(expand_7, [mul_11, 128, sym_size_int_3]); expand_7 = mul_11 = None
bmm_1: "f32[32*s0, 1, s2 + 1]" = torch.ops.aten.bmm.default(view_13, view_14); view_13 = view_14 = None
view_15: "f32[s0, 32, 1, s2 + 1]" = torch.ops.aten.view.default(bmm_1, [sym_size_int_1, 32, 1, sym_size_int_3]); bmm_1 = None
add_6: "f32[s0, 32, 1, s2 + 1]" = torch.ops.aten.add.Tensor(view_15, slice_21); view_15 = slice_21 = None
_softmax: "f32[s0, 32, 1, s2 + 1]" = torch.ops.aten._softmax.default(add_6, -1, False); add_6 = None
expand_8: "f32[s0, 32, 1, s2 + 1]" = torch.ops.aten.expand.default(_softmax, [sym_size_int_1, 32, 1, sym_size_int_3]); _softmax = None
mul_12: "Sym(32*s0)" = sym_size_int_1 * 32
view_16: "f32[32*s0, 1, s2 + 1]" = torch.ops.aten.view.default(expand_8, [mul_12, 1, sym_size_int_3]); expand_8 = sym_size_int_3 = None
expand_9: "f32[s0, 32, s2 + 1, 128]" = torch.ops.aten.expand.default(_unsafe_view_1, [sym_size_int_1, 32, sym_size_int_5, 128]); _unsafe_view_1 = None
view_17: "f32[32*s0, s2 + 1, 128]" = torch.ops.aten.view.default(expand_9, [mul_12, sym_size_int_5, 128]); expand_9 = mul_12 = sym_size_int_5 = None
bmm_2: "f32[32*s0, 1, 128]" = torch.ops.aten.bmm.default(view_16, view_17); view_16 = view_17 = None
view_18: "f32[s0, 32, 1, 128]" = torch.ops.aten.view.default(bmm_2, [sym_size_int_1, 32, 1, 128]); bmm_2 = None
transpose_8: "f32[s0, 1, 32, 128]" = torch.ops.aten.transpose.int(view_18, 1, 2); view_18 = None
transpose_9: "f32[s0, 32, 1, 128]" = torch.ops.aten.transpose.int(transpose_8, 1, 2); transpose_8 = None
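For reference, the export call is roughly the same in both torch versions. This is only a minimal sketch: the checkpoint name, input construction, and dynamic-dim setup below are placeholders for my actual setup (which also passes the attention mask and KV cache, where the s0 / s2 symbols come from).

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; I actually export a Llama model from transformers.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").eval()

# Illustrative single-token decode input with a dynamic batch dimension.
input_ids = torch.randint(0, model.config.vocab_size, (2, 1))
batch = torch.export.Dim("batch")

ep = torch.export.export(
    model,
    (input_ids,),
    dynamic_shapes={"input_ids": {0: batch}},
)
print(ep.graph_module.print_readable())
```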
But now, with torch 2.6/2.7, torch.nn.functional.scaled_dot_product_attention is no longer decomposed, like below:
unsqueeze_10: "f32[1, 1, s2 + 2]" = torch.ops.aten.unsqueeze.default(mul_9, 0); mul_9 = None
unsqueeze_11: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.unsqueeze.default(unsqueeze_10, 1); unsqueeze_10 = None
slice_22: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 2, 0, 9223372036854775807); unsqueeze_11 = None
slice_23: "f32[1, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_22, 3, 0, 9223372036854775807); slice_22 = None
expand_4: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.expand.default(slice_23, [sym_size_int_14, 1, -1, -1]); slice_23 = None
slice_24: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(expand_4, 0, 0, 9223372036854775807); expand_4 = None
slice_25: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_24, 1, 0, 9223372036854775807); slice_24 = None
slice_26: "f32[s0, 1, 1, s2 + 2]" = torch.ops.aten.slice.Tensor(slice_25, 2, 0, 9223372036854775807); slice_25 = None
slice_27: "f32[s0, 1, 1, s2 + 1]" = torch.ops.aten.slice.Tensor(slice_26, 3, 0, add_4); slice_26 = add_4 = None
scaled_dot_product_attention: "f32[s0, 32, 1, 128]" = torch.ops.aten.scaled_dot_product_attention.default(add_133, _unsafe_view, _unsafe_view_1, slice_27); add_133 = _unsafe_view = _unsafe_view_1 = slice_27 = None
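What I would like is a way to get the old decomposed form back from the 2.6/2.7 export. I am not sure whether ExportedProgram.run_decompositions() is the intended switch here, or whether it actually lowers scaled_dot_product_attention; a rough, unverified sketch of what I mean, reusing the `ep` from the export call above:

```python
# Reusing the exported program `ep` from the sketch above (torch 2.6/2.7).
# Running the default decompositions is my guess at how to lower
# aten.scaled_dot_product_attention back to the bmm/softmax form,
# but I have not confirmed that it does.
ep_decomposed = ep.run_decompositions()
print(ep_decomposed.graph_module.print_readable())
```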
Does anybody know the reason for this change, and is there a switch to toggle between these two behaviors? Thanks a lot!