How to set functions as black boxes (or wrap them) when using torch dynamo to capture the backward graph?

I am now using

from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition
m, _ = aot_export_module(fwd, [inp], trace_joint=True, output_loss_index=0, decompositions=None)
fwd, bwd = default_partition(m, [inp], num_fwd_outputs=1)

to capture the forward and backward graphs based on AOT Autograd and Dynamo.
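
For context, here is a minimal, self-contained sketch of that flow; ToyModel, the shapes, and the import paths are my assumptions (these are internal, underscore-prefixed APIs, so they may move between versions):

import torch
from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # Output index 0 must be the scalar loss (matches output_loss_index=0).
        return self.linear(x).sum()

model = ToyModel()
inp = torch.randn(2, 4)

joint, _ = aot_export_module(model, [inp], trace_joint=True, output_loss_index=0)
fwd_g, bwd_g = default_partition(joint, [inp], num_fwd_outputs=1)
print(fwd_g.graph)
print(bwd_g.graph)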

In the past, we could mark a function with @torch.fx.wrap to keep it from being decomposed into lower-level ops. ref: torch.fx — PyTorch 2.3 documentation
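
For example, under symbolic tracing a wrapped function stays a single call_function node instead of being traced through (my_helper is just an illustrative name):

import torch
import torch.fx

@torch.fx.wrap
def my_helper(x):
    return x * 2 + 1

class M(torch.nn.Module):
    def forward(self, x):
        return my_helper(x)

traced = torch.fx.symbolic_trace(M())
print(traced.graph)  # my_helper appears as one opaque node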

So I wonder whether there is any way to do the same thing in this flow?

One possible way to do it is to define a custom op, as described here:
https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial

But that page says it requires PyTorch 2.4, which is unfortunately not available to me.

You should still be able to create custom ops via the lower-level torch.library.Library APIs.

import torch

lib = torch.library.Library("foo", "FRAGMENT")
lib.define("bar(Tensor x) -> Tensor")

def bar_impl(x):
    return x.clone()

# Register the kernel for CPU and CUDA, plus a Meta kernel for shape propagation under tracing.
lib.impl("bar", bar_impl, "CPU")
lib.impl("bar", bar_impl, "CUDA")
lib.impl("bar", bar_impl, "Meta")

The newer custom-op APIs are intended to make the process easier and less error-prone, so it is still recommended that you upgrade if possible.
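
For reference, the equivalent definition under the newer API (PyTorch 2.4+) is roughly:

import torch

@torch.library.custom_op("foo::bar", mutates_args=())
def bar(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

# Fake (meta) kernel so the op supports tracing and shape propagation.
@bar.register_fake
def _(x):
    return torch.empty_like(x)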


I understand, thank you.

Update: How to set wrap function using TorchDynamo graph capture? - #4 by wmhst7 - compiler - PyTorch Developer Mailing List