I want to dump the FX graph that is executed for `torch.autograd.grad(y, W)` on a linear module y = b + x * W^T. However, the FX graph passed to `bw_compiler` computes every gradient, whereas I only want the code that computes the gradient of W:
# Backward graph as printed from the bw_compiler callback (fx_module.code).
# NOTE(review): tangents_1 is presumably the upstream gradient of y, primals_3
# the saved input x, and t the saved transposed weight -- confirm against the
# matching forward graph.
def forward(self, primals_3, t, tangents_1):
    # Transpose the saved tensor t, then multiply: mm is the third output.
    t_1 = torch.ops.aten.t.default(t); t = None
    mm = torch.ops.aten.mm.default(tangents_1, t_1); t_1 = None
    # tangents_1^T @ primals_3, transposed afterwards.
    t_2 = torch.ops.aten.t.default(tangents_1)
    mm_1 = torch.ops.aten.mm.default(t_2, primals_3); t_2 = primals_3 = None
    t_3 = torch.ops.aten.t.default(mm_1); mm_1 = None
    # Sum tangents_1 over dim 0 (keepdim=True) and reshape to [4]: second output.
    sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
    view = torch.ops.aten.view.default(sum_1, [4]); sum_1 = None
    # Second transpose of the mm_1 result: first output.
    t_4 = torch.ops.aten.t.default(t_3); t_3 = None
    # All three results are returned together in one graph.
    return (t_4, view, mm)
Here `t_4`, `view`, and `mm` are the gradients of W, b, and x, respectively.
How can I generate a backward graph that contains only the computation of the gradient of W? Here is the complete code:
import torch
from functorch.compile import aot_function, make_boxed_func, aot_module
class MyLinearGrad0(torch.nn.Module):
    """Minimal module wrapping a single ``torch.nn.Linear(4, 4)`` layer.

    Serves as the model whose forward and backward graphs are captured in
    ``main``.
    """

    def __init__(self) -> None:
        super().__init__()
        # 4 input features -> 4 output features; bias is left at the
        # ``torch.nn.Linear`` default.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear layer to ``x`` and return the result."""
        return self.linear(x)
def compiler_print(fx_module: torch.fx.GraphModule, _):
    """Print the Python code of a captured FX graph and return it runnable.

    Intended to be passed as ``fw_compiler`` / ``bw_compiler``.

    Args:
        fx_module: The captured graph module whose ``.code`` is printed.
        _: Second positional argument supplied by the caller; ignored here.

    Returns:
        ``fx_module`` wrapped by ``make_boxed_func``.
    """
    print(fx_module.code)
    # No real compilation happens: the graph module is handed back as-is,
    # only adapted by make_boxed_func.
    return make_boxed_func(fx_module)
def main():
    """Wrap ``MyLinearGrad0`` with ``aot_module`` and request only dL/dW.

    Both ``fw_compiler`` and ``bw_compiler`` are set to ``compiler_print``,
    so every graph handed to either callback has its code printed.
    """
    x = torch.randn([2, 4], requires_grad=True)
    linear_grad0_gen = MyLinearGrad0()
    aot_linear_grad0_gen = aot_module(
        linear_grad0_gen,
        fw_compiler=compiler_print,
        bw_compiler=compiler_print,
    )
    # Fresh leaf tensor with the same values as x; it also requires grad.
    x_gen = x.clone().detach().requires_grad_(True)
    y_gen = aot_linear_grad0_gen(x_gen)
    # grad_x = torch.autograd.grad(y_gen.sum(), x_gen)
    # Only the weight gradient is requested here; the result is not used
    # further because the printed backward graph is what is of interest.
    grad_w = torch.autograd.grad(y_gen.sum(), linear_grad0_gen.linear.weight)
    # y_gen.sum().backward()


# Run the reproduction only when executed as a script.
if __name__ == "__main__":
    main()
My PyTorch version is:
pip show torch
Name: torch
Version: 2.9.1+cpu
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org
Author:
Author-email: PyTorch Team <packages@pytorch.org>
License: BSD-3-Clause
Location: /home/dym/.conda/envs/aot/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: torchvision