Below is my custom torch.compile backend. I want aot_autograd to export the backward graph after my passes have run on the forward graph. In the forward graph I may replace an op that was registered through torch.library.custom_op() and bound to its backward op. However, after I modify the forward graph, the exported backward graph does not change to reflect my modification. Is there any way to achieve what I expect?
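For context, the custom op being rewritten is registered roughly like this. This is only a simplified sketch: mylib::my_gelu, its fake kernel, and its backward are placeholder names used for illustration, not my real op.

import math
import torch

# Hypothetical custom op of the kind the backend rewrites ("mylib::my_gelu"
# is a placeholder name used only for illustration).
@torch.library.custom_op("mylib::my_gelu", mutates_args=())
def my_gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x)

@my_gelu.register_fake
def _(x):
    # Fake (meta) kernel so the op can be traced by dynamo / aot_autograd.
    return torch.empty_like(x)

def _setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def _backward(ctx, grad):
    (x,) = ctx.saved_tensors
    # d/dx gelu(x) = Phi(x) + x * phi(x), with Phi / phi the standard normal CDF / PDF.
    phi = torch.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    Phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return grad * (Phi + x * phi)

# This is what I mean by binding the forward op to its backward.
my_gelu.register_autograd(_backward, setup_context=_setup_context)

And here is the backend: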
import torch
import functools
from torch._dynamo import register_backend
from torch._dynamo.backends.common import aot_autograd, device_from_inputs
from functorch.compile import make_boxed_func, min_cut_rematerialization_partition
# torch._dynamo.config.suppress_errors = True
from .passes import *


@register_backend
def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    def my_compiler_forward(gm, example_inputs):
        # Run my FX pass over the forward graph that aot_autograd hands to fw_compiler.
        torch_pass(gm)

        def exec_fw(*i_args):
            return gm.forward(*i_args)

        return make_boxed_func(exec_fw)

    def my_compiler_backward(gm, example_inputs):
        # The backward graph arrives here; I expected it to reflect the forward rewrite.
        def exec_bw(*i_args):
            return gm.forward(*i_args)

        return make_boxed_func(exec_bw)

    return aot_autograd(
        fw_compiler=my_compiler_forward,
        bw_compiler=my_compiler_backward,
    )(gm, example_inputs)
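The actual pass lives in .passes. A simplified stand-in for what it does (not my real pass; rewriting to aten.gelu is chosen only so the sketch is self-contained) would be:

def torch_pass(gm: torch.fx.GraphModule):
    # Swap every call to the custom forward op for another op and recompile
    # the module so the rewritten graph is what actually runs.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.mylib.my_gelu.default:
            node.target = torch.ops.aten.gelu.default
    gm.graph.lint()
    gm.recompile()

I drive the backend roughly like this (toy module, illustrative only); calling .backward() on the compiled output is what makes aot_autograd hand a backward graph to bw_compiler:

class ToyModel(torch.nn.Module):
    def forward(self, x):
        return my_gelu(x)

compiled = torch.compile(ToyModel(), backend="my_compiler")
compiled(torch.randn(2, 8, requires_grad=True)).sum().backward()

Even after the forward graph is rewritten this way, the backward graph passed to my_compiler_backward is unchanged, which is the behavior I would like to avoid.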