Issue when using aot_module with torch.autograd

I want to dump the FX graph produced when executing `torch.autograd.grad(y, W)` for a linear module y = b + x * W^T. But the graph passed to `bw_compiler` computes every gradient; I only want the code that computes the gradient of W:

def forward(self, primals_3, t, tangents_1):
    t_1 = torch.ops.aten.t.default(t);  t = None
    mm = torch.ops.aten.mm.default(tangents_1, t_1);  t_1 = None
    t_2 = torch.ops.aten.t.default(tangents_1)
    mm_1 = torch.ops.aten.mm.default(t_2, primals_3);  t_2 = primals_3 = None
    t_3 = torch.ops.aten.t.default(mm_1);  mm_1 = None
    sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True);  tangents_1 = None
    view = torch.ops.aten.view.default(sum_1, [4]);  sum_1 = None
    t_4 = torch.ops.aten.t.default(t_3);  t_3 = None
    return (t_4, view, mm)

Here t_4, view, and mm are the gradients of W, b, and x, respectively.

How can I generate a backward graph that computes only the gradient of W? Here is the complete code:

import torch
from functorch.compile import aot_function, make_boxed_func, aot_module

class MyLinearGrad0(torch.nn.Module):
    """Thin wrapper around a single 4x4 linear layer (y = x @ W^T + b)."""

    def __init__(self):
        super().__init__()
        # Square layer: both input and output features have size 4.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        out = self.linear(x)
        return out

def compiler_print(fx_module: torch.fx.GraphModule, _):
    """AOTAutograd compiler hook: dump the generated Python code of the
    traced graph, then return it wrapped for the boxed calling convention."""
    print(fx_module.code)
    boxed = make_boxed_func(fx_module)
    return boxed

def main():
    """Trace a Linear module through AOTAutograd and print both the
    forward and backward FX graphs via `compiler_print`."""
    base_input = torch.randn([2, 4], requires_grad=True)
    module = MyLinearGrad0()

    compiled = aot_module(module, fw_compiler=compiler_print, bw_compiler=compiler_print)
    probe = base_input.clone().detach().requires_grad_(True)
    out = compiled(probe)
    # Any of these triggers the (single, joint) backward graph:
    # grad_x = torch.autograd.grad(out.sum(), probe)
    grad_w = torch.autograd.grad(out.sum(), module.linear.weight)
    # out.sum().backward()


if __name__ == "__main__":
    main()

and pytorch version:

pip show torch
Name: torch
Version: 2.9.1+cpu
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org
Author:
Author-email: PyTorch Team <packages@pytorch.org>
License: BSD-3-Clause
Location: /home/dym/.conda/envs/aot/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: torchvision