Exact code line for checking if node op is instance of torch.ops.OpOverload

Dear community,

I’ve been digging around torch.compile and TorchDynamo but couldn’t find the exact line of code where
make_fx checks whether an op is an instance of torch.ops.OpOverload (or converts a higher-level op to an OpOverload). It would be really helpful if anyone could point me to that line.

Hey!

make_fx() doesn’t directly check whether an op is an instance of torch.ops.OpOverload. Instead, it relies on tracing through all of the code in torch.* (and the PyTorch dispatcher), basically the same as if you were running your code in eager mode.

In eager mode, there are a few steps that happen before we execute a kernel. They look roughly like:
(1) Run any Python code (e.g. all of the code in torch.nn.MaxPool2d)
(2) Cross the Python ↔ C++ boundary (e.g. torch.max_pool2d)
(3) Enter the PythonArgParser, performing overload resolution (this is normally where we desugar from a torch op to an individual “OpOverload”, e.g. `aten.max_pool2d.default`; see the snippet just after this list)
(4) For that OpOverload, enter the PyTorch dispatcher
(5) For that OpOverload, execute its cpu/cuda kernel
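
To make step (3) concrete, here’s a small snippet you can run to poke at the object that overload resolution produces (note that the OpOverload class itself lives at torch._ops.OpOverload; attribute access on torch.ops resolves to op namespaces, not classes):

import torch

# The kind of object produced by overload resolution in step (3):
op = torch.ops.aten.max_pool2d.default

print(isinstance(op, torch._ops.OpOverload))  # True
print(op)          # aten.max_pool2d.default
print(op._schema)  # the full aten schema for this overload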

make_fx() tracing involves running all of the same code. However, right before getting to the cpu/cuda kernel, we jump to make_fx’s __torch_dispatch__ code, which traces each op into a graph (here: https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L600)
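
As a rough sketch of what that interception looks like (a minimal __torch_dispatch__ mode, not make_fx’s actual implementation): by the time control reaches __torch_dispatch__, the func argument is already a resolved OpOverload, because overload resolution happened back in step (3).

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class PrintingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # `func` here is already an OpOverload, e.g. aten.mul.Tensor
        print(type(func), func)
        return func(*args, **(kwargs or {}))

with PrintingMode():
    torch.mul(torch.ones(2), 2)
# prints lines like: <class 'torch._ops.OpOverload'> aten.mul.Tensor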


Here are some hopefully-helpful related readings:

A guide stepping slowly through python bindings + the C++ dispatcher for a single operator (torch.mul): PyTorch dispatcher walkthrough · pytorch/pytorch Wiki · GitHub

Horace’s post about __torch_dispatch__: What (and Why) is __torch_dispatch__? - frontend API - PyTorch Dev Discussions

Ed’s blog about the dispatcher: Let’s talk about the PyTorch dispatcher : ezyang’s blog.

Also - if you happen to build PyTorch from source with the DEBUG=1 build flag, then there’s a handy runtime environment variable you can use to print every dispatch key that gets hit:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    out = torch.mul(x, 2)
    out.sum().backward()
    return out, x.grad

x = torch.ones(2, requires_grad=True)
gm = make_fx(f)(x)

# If you have a debug build, you can try running:
# TORCH_SHOW_DISPATCH_TRACE=1 python script.py

Thanks so much for the detailed explanation!

I have another question regarding the PythonArgParser (the desugaring part). My use case: I have a custom function (an empty function that returns void) that gets included in the GraphModule via torch._dynamo.allow_in_graph, but it disappears after running make_fx. I assume this is natural, in the sense that functions added via allow_in_graph are still traced through, and since my custom function is empty, it is omitted. Does this exclusion happen in the PythonArgParser, or somewhere else?

Thanks in advance.

If you want your custom function to survive make_fx, then you will need to register it as a custom operator with the PyTorch dispatcher.

(cc @richard - do we have any docs yet for python-only custom ops?)

In the meantime, there are also some test cases with examples that you can look at: https://github.com/pytorch/pytorch/blob/main/test/test_custom_ops.py#L1552.
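
As a rough sketch of what that registration can look like, here’s a minimal example using the torch.library API (the namespace mylib and the op name my_marker are made up for illustration; the tests linked above show more complete patterns). Because the kernel is registered below the Python dispatch key, make_fx’s __torch_dispatch__ sees the op before the kernel runs and records it in the graph:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# "mylib" / "my_marker" are made-up names for this sketch.
lib = torch.library.Library("mylib", "DEF")
lib.define("my_marker(Tensor x) -> Tensor")

def my_marker_impl(x):
    # Nearly a no-op; clone() keeps the kernel consistent with the
    # non-aliasing schema declared above.
    return x.clone()

lib.impl("my_marker", my_marker_impl, "CompositeExplicitAutograd")

def f(x):
    return torch.ops.mylib.my_marker(x) * 2

gm = make_fx(f)(torch.ones(2))
print(gm.graph)  # should show a call to mylib.my_marker.default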
