PyTorch 2.1.0 doesn't support black-box custom ops through AOT graph capture?

I use torch.library to define and implement custom operations, and then use aot_export_module (AOT module export) to capture the computational graph. However, I've run into a discrepancy between PyTorch versions: in 2.1.0 the custom ops are decomposed away during capture, whereas in 2.3.0 they are preserved in the graph.

So is keeping custom ops as opaque (black-box) nodes not supported in 2.1.0?

Here are two snippets of code for comparison:

For PyTorch 2.1.0:

import torch
import torch.nn.functional as F
from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition

mylib = torch.library.Library("mylib", "FRAGMENT")

# Define forward op
mylib.define("bar(Tensor x) -> Tensor")

# @torch.library.impl("mylib::bar", "default")
def bar_impl(x):
    return torch.empty_like(x)

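# Register the Python implementation; note that no dispatch key is passed here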
mylib.impl("bar", bar_impl)

# Define backward op
mylib.define("bar_backward(Tensor grad, Tensor x) -> Tensor")

def bar_backward(grad, x):
    return torch.empty_like(x)

mylib.impl("bar_backward", bar_backward)

# Create an autograd.Function with the forward and backward
class CustomFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.ops.mylib.bar(x)

    @staticmethod
    def backward(ctx, grad):
        x = ctx.saved_tensors[0]
        return torch.ops.mylib.bar_backward.default(grad, x)

def custom_func(x):
    return CustomFunc.apply(x)


class CustomModel(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

    def forward(self, x):
        x = custom_func(x)
        x = torch.mm(x, self.w1)
        # x = F.gelu(x)
        x = custom_func(x)
        x = x.sum()
        return (x,)


if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("meta"):
        hidden_size = 1024
        model = CustomModel(hidden_size)
        inp = torch.zeros(2, hidden_size, requires_grad=True)
        m, _ = aot_export_module(model, [inp], trace_joint=True, output_loss_index=0, decompositions=None)
        fwd, bwd = default_partition(m, [inp], num_fwd_outputs=1)
        
        fwd.graph.print_tabular()
        bwd.graph.print_tabular()

Result (the custom ops do not appear in the graphs; only the aten.empty_like calls from their Python implementations remain):

opcode         name          target                   args                                kwargs
-------------  ------------  -----------------------  ----------------------------------  ---------------------
placeholder    arg0_1        arg0_1                   ()                                  {}
placeholder    arg1_1        arg1_1                   ()                                  {}
call_function  empty_like    aten.empty_like.default  (arg1_1,)                           {'pin_memory': False}
call_function  mm            aten.mm.default          (empty_like, arg0_1)                {}
call_function  empty_like_1  aten.empty_like.default  (mm,)                               {'pin_memory': False}
call_function  sum_1         aten.sum.default         (empty_like_1,)                     {}
output         output        output                   ([sum_1, arg1_1, empty_like, mm],)  {}
opcode         name          target                   args                     kwargs
-------------  ------------  -----------------------  -----------------------  ---------------------
placeholder    arg1_1        arg1_1                   ()                       {}
placeholder    empty_like    empty_like               ()                       {}
placeholder    mm            mm                       ()                       {}
call_function  empty_like_2  aten.empty_like.default  (mm,)                    {'pin_memory': False}
call_function  t             aten.t.default           (empty_like,)            {}
call_function  mm_1          aten.mm.default          (t, empty_like_2)        {}
call_function  empty_like_3  aten.empty_like.default  (arg1_1,)                {'pin_memory': False}
output         output        output       
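
(Side note: to check which dispatch key the Python kernel of a custom op actually ended up under, the private helper torch._C._dispatch_dump can print the op's dispatch table; it is an internal API, so treat it purely as a debugging aid.)

print(torch._C._dispatch_dump("mylib::bar"))  # lists the kernels registered for mylib::bar per dispatch key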

For PyTorch 2.3.0:

import torch
import torch.nn.functional as F
from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition

# Define forward op
torch.library.define("mylib::bar", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::bar", "default")
def bar_impl(x):
    return torch.empty_like(x)

# Define backward op
torch.library.define("mylib::bar_backward", "(Tensor grad, Tensor x) -> Tensor")

@torch.library.impl("mylib::bar_backward", "default")
def bar_backward(grad, x):
    return torch.empty_like(x)

# Create an autograd.Function with the forward and backward
class CustomFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.ops.mylib.bar(x)

    @staticmethod
    def backward(ctx, grad):
        x = ctx.saved_tensors[0]
        return torch.ops.mylib.bar_backward(grad, x)

def custom_func(x):
    return CustomFunc.apply(x)


class CustomModel(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

    def forward(self, x):
        x = custom_func(x)
        x = torch.mm(x, self.w1)
        x = F.gelu(x)
        x = custom_func(x)
        x = x.sum()
        return (x,)


if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("meta"):
        hidden_size = 1024
        model = CustomModel(hidden_size)
        inp = torch.zeros(2, hidden_size, requires_grad=True)
        m, _ = aot_export_module(model, [inp], trace_joint=True, output_loss_index=0, decompositions=None)
        fwd, bwd = default_partition(m, [inp], num_fwd_outputs=1)
        
        fwd.graph.print_tabular()
        bwd.graph.print_tabular()

Result (mylib.bar and mylib.bar_backward are preserved as opaque calls in both the forward and backward graphs):

opcode         name    target             args                                              kwargs
-------------  ------  -----------------  ------------------------------------------------  --------
placeholder    arg0_1  arg0_1             ()                                                {}
placeholder    arg1_1  arg1_1             ()                                                {}
call_function  bar     mylib.bar.default  (arg1_1,)                                         {}
call_function  mm      aten.mm.default    (bar, arg0_1)                                     {}
call_function  gelu    aten.gelu.default  (mm,)                                             {}
call_function  bar_1   mylib.bar.default  (gelu,)                                           {}
call_function  sum_1   aten.sum.default   (bar_1,)                                          {}
output         output  output             ([sum_1, arg0_1, arg1_1, bar, mm, gelu, sum_1],)  {}
opcode         name            target                      args                       kwargs
-------------  --------------  --------------------------  -------------------------  -------------------------------------------------------------
placeholder    arg0_1          arg0_1                      ()                         {}
placeholder    arg1_1          arg1_1                      ()                         {}
placeholder    bar             bar                         ()                         {}
placeholder    mm              mm                          ()                         {}
placeholder    gelu            gelu                        ()                         {}
placeholder    sum_1           sum_1                       ()                         {}
call_function  ones_like       aten.ones_like.default      (sum_1,)                   {'pin_memory': False, 'memory_format': torch.preserve_format}
call_function  expand          aten.expand.default         (ones_like, [2, 1024])     {}
call_function  bar_backward    mylib.bar_backward.default  (expand, gelu)             {}
call_function  gelu_backward   aten.gelu_backward.default  (bar_backward, mm)         {}
call_function  t               aten.t.default              (bar,)                     {}
call_function  mm_1            aten.mm.default             (t, gelu_backward)         {}
call_function  t_1             aten.t.default              (arg0_1,)                  {}
call_function  mm_2            aten.mm.default             (gelu_backward, t_1)       {}
call_function  bar_backward_1  mylib.bar_backward.default  (mm_2, arg1_1)             {}
output         output          output       

Update: registering the implementations under the CompositeExplicitAutograd dispatch key in 2.1.0 solved this problem; the custom ops then stay opaque in the captured graphs.
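
For reference, here is the registration change that worked on 2.1.0 (a minimal sketch; only these two calls differ from the first snippet, everything else stays the same):

# Register the kernels under CompositeExplicitAutograd so that AOT export keeps
# mylib.bar / mylib.bar_backward as opaque nodes instead of tracing through
# their Python bodies.
mylib.impl("bar", bar_impl, "CompositeExplicitAutograd")
mylib.impl("bar_backward", bar_backward, "CompositeExplicitAutograd")

With this change the captured forward and backward graphs contain mylib.bar.default and mylib.bar_backward.default calls, similar to the 2.3.0 output above.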