I’ve been digging around torch.compile and TorchDynamo but couldn’t find the exact line of code where make_fx checks whether an op is an instance of torch.ops.OpOverload (or converts a higher-level op to an OpOverload). It would be really helpful if anyone could point me to that code.
make_fx() doesn’t directly check whether an op is an instance of torch.ops.OpOverload. Instead, it relies on tracing through all of the code in torch.* (and the PyTorch dispatcher), basically the same as if you were running your code in eager mode.
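To see this concretely, here’s a small sketch (the `OpOverload` class actually lives in `torch._ops`) showing that the `call_function` targets recorded by make_fx end up as OpOverload instances, even though make_fx itself never performs an explicit isinstance check:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._ops import OpOverload

def f(x):
    return torch.mul(x, 2)

# Tracing desugars torch.mul down to a specific ATen overload
gm = make_fx(f)(torch.ones(2))
for node in gm.graph.nodes:
    if node.op == "call_function":
        # e.g. aten.mul.Tensor, which is an OpOverload instance
        print(node.target, isinstance(node.target, OpOverload))
```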
In eager mode, there are a couple of steps that happen before we execute a kernel. They look roughly like:
(1) Run any Python code (e.g. all of the code in torch.nn.MaxPool2d)
(2) Enter the Python ↔ C++ boundary (e.g. torch.max_pool2d)
(3) Enter the PythonArgParser, performing overload resolution (this is normally where we desugar from a torch op to an individual “OpOverload”, e.g. `aten.max_pool2d.default`)
(4) For that OpOverload, enter the PyTorch dispatcher
(5) For that OpOverload, execute its CPU/CUDA kernel
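To make step (3) concrete: `torch.ops.aten.mul` is an OpOverloadPacket that groups every overload of the op, and overload resolution picks one specific OpOverload off of it (both class names are from `torch._ops`):

```python
import torch
from torch._ops import OpOverload, OpOverloadPacket

packet = torch.ops.aten.mul            # groups all overloads of aten::mul
overload = torch.ops.aten.mul.Tensor   # one resolved overload

print(isinstance(packet, OpOverloadPacket))  # True
print(isinstance(overload, OpOverload))      # True
print(overload._schema)                      # the overload's schema
```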
Also - if you happen to build pytorch from source, and you build with the DEBUG=1 build flag, then we have a handy runtime env-var that you can use to print every dispatch key that gets hit:
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    out = torch.mul(x, 2)
    out.sum().backward()
    return out, x.grad

x = torch.ones(2, requires_grad=True)
gm = make_fx(f)(x)

# If you have a debug build, you can try running:
# TORCH_SHOW_DISPATCH_TRACE=1 python script.py
```
I have another question regarding the PythonArgParser (the desugaring part). My use case is that I have a custom function (an empty function that returns void) that gets included in the GraphModule via torch._dynamo.allow_in_graph, but disappears after running make_fx. I assume this is natural, in the sense that functions added via allow_in_graph are still traced through, and since my custom function is empty, it is omitted. Does this exclusion happen in the PythonArgParser, or somewhere else?
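For what it’s worth, the disappearance can be reproduced without dynamo at all: since make_fx traces through Python, a function containing no tensor ops simply leaves no nodes behind. A minimal sketch (`custom_fn` here is a stand-in for the empty custom function described above):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def custom_fn(x):
    # stand-in for the empty custom function: no tensor ops, returns nothing
    pass

def f(x):
    custom_fn(x)
    return torch.mul(x, 2)

gm = make_fx(f)(torch.ones(2))
# custom_fn contributes no nodes; only the mul op is recorded
print(gm.graph)
```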