What do functions like torch._C._DisablePythonDispatcher do?

While reading the source code of make_fx, I encountered three functions:

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

My guess is that they are an important set of functions related to the dispatcher, but even after searching the codebase I still can’t summarize exactly how they work. Could anyone explain what these functions do?


These are bound to C++ RAII guards: https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/init.cpp#L447 (all 3 are in this file if you grep for them).

These guards are similar to context managers in Python: they temporarily set some global (thread-local) state on entry and restore the previous state on exit. In all 3 of these cases, that “state” maps to a dispatch key (if you’re interested in what dispatch keys are and how the PyTorch dispatcher works, check out Ed’s blog post “Let’s talk about the PyTorch dispatcher” on ezyang’s blog).
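If it helps to see the pattern in pure Python, here is a toy analogue of what these guards do. It is only a sketch: the KeyGuard class, the string key names, and the three helper functions are hypothetical stand-ins, not PyTorch’s implementation (the real guards live in C++ and manipulate c10’s thread-local dispatch key sets). It shows the two properties that matter: a nested guard restores whatever state was active before it, and the state is per-thread.

```python
import threading
from typing import FrozenSet, Optional

# Toy model, NOT PyTorch's implementation: each thread has its own set of
# "included" dispatch key names, analogous to the TLS the C++ guards modify.
_tls = threading.local()


def included_keys() -> FrozenSet[str]:
    """Return the calling thread's included key names (empty if never set)."""
    return getattr(_tls, "included", frozenset())


class KeyGuard:
    """Hypothetical RAII-style guard: turn `key` on/off on entry, restore on exit."""

    def __init__(self, key: str, enabled: bool) -> None:
        self.key = key
        self.enabled = enabled
        self._saved: Optional[FrozenSet[str]] = None

    def __enter__(self) -> "KeyGuard":
        self._saved = included_keys()  # remember the previous state, like a C++ member
        if self.enabled:
            _tls.included = self._saved | {self.key}
        else:
            _tls.included = self._saved - {self.key}
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        _tls.included = self._saved  # restore, even if the body raised
        return False  # do not swallow exceptions


# Hypothetical stand-ins mirroring the three aliases used by make_fx.
def enable_python_dispatcher() -> KeyGuard:
    return KeyGuard("PythonDispatcher", True)


def no_python_dispatcher() -> KeyGuard:
    return KeyGuard("PythonDispatcher", False)


def enable_pre_dispatch() -> KeyGuard:
    return KeyGuard("PreDispatch", True)


with enable_python_dispatcher():
    outer = included_keys()
    with no_python_dispatcher():
        inner = included_keys()
    restored = included_keys()

    # A different thread does not see this thread's state.
    other_thread_view = []
    worker = threading.Thread(target=lambda: other_thread_view.append(included_keys()))
    worker.start()
    worker.join()

with enable_pre_dispatch():
    pre_view = included_keys()

final = included_keys()
print(outer, inner, restored, other_thread_view, pre_view, final)
```

The "save on enter, restore on exit" design (rather than blindly clearing the key on exit) is what makes nesting safe: exiting no_python_dispatcher() inside an enable_python_dispatcher() block puts the key back instead of leaving it off.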

The snippet you quoted involves two dispatch keys that affect the behavior of make_fx. I’ll cover each of them in a bit more detail.

DispatchKey::PythonDispatcher (here: https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L427)

The “python dispatcher” is a piece of infra that lets us override the functionality of every existing dispatch key from Python. It’s not user-facing, but it has been very convenient for adding functionality in core. For example, we have C++ decompositions for many operations that normally run in the dispatcher in eager mode, and we’ve used the python dispatcher to override some of them with Python decompositions. Example: the matmul decomp here: https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L3852. When we capture a graph with torch.compile (with the default inductor backend), the python dispatcher is always enabled, which allows these decomps to run. It has also let us add things like debugging asserts that run during functionalization (here: https://github.com/pytorch/pytorch/blob/main/torch/_ops.py#L708).

DispatchKey::PreDispatch (here: https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L423)

Normally, when you run gm = make_fx(f), you capture a graph of ATen ops by tracing all the way through the dispatcher. This means that every “functionality” baked into the dispatcher is traced through and reflected in the resulting graph: autograd, autocast, functorch transforms, decomps that live in the dispatcher, and several others.
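A toy model of why the tracer’s position in the key ordering matters. Each hypothetical “functionality” layer (a made-up autocast layer that inserts casts, and a made-up decomposition layer that splits addmm into mm + add) rewrites a call before passing it further down, and the tracer records whatever reaches the depth it sits at. This is not how make_fx is implemented, and the layer order is invented; it only illustrates that a tracer at the bottom sees the effects of every layer above it, while a tracer at the top sees the original call.

```python
from typing import Callable, List, Tuple

# Toy model, NOT PyTorch internals. An op call is (name, args); each "functionality"
# layer rewrites one call into the list of calls it passes further down the stack.
Call = Tuple[str, Tuple[str, ...]]
Layer = Callable[[Call], List[Call]]


def autocast_layer(call: Call) -> List[Call]:
    # Hypothetical: cast every input to fp16, then run the op on the cast inputs.
    name, args = call
    casts = [("to_fp16", (a,)) for a in args]
    return casts + [(name, tuple(f"{a}_fp16" for a in args))]


def decomp_layer(call: Call) -> List[Call]:
    # Hypothetical: decompose addmm(bias, x, w) into mm + add; pass other ops through.
    name, args = call
    if name == "addmm":
        bias, x, w = args
        return [("mm", (x, w)), ("add", ("<mm result>", bias))]
    return [call]


LAYERS: List[Layer] = [autocast_layer, decomp_layer]  # ordered top -> bottom


def trace(calls: List[Call], trace_depth: int) -> List[Call]:
    """Record the calls that reach `trace_depth` (0 = top, len(LAYERS) = bottom)."""
    graph: List[Call] = []

    def dispatch(call: Call, depth: int) -> None:
        if depth == trace_depth:
            graph.append(call)  # the tracer intercepts here and records the op
            return
        for sub in LAYERS[depth](call):
            dispatch(sub, depth + 1)

    for call in calls:
        dispatch(call, 0)
    return graph


program = [("addmm", ("b", "x", "w"))]
bottom = trace(program, trace_depth=len(LAYERS))  # tracer below every layer
top = trace(program, trace_depth=0)               # tracer above every layer
print([name for name, _ in bottom])
print(top)
```

Here bottom contains the three casts followed by mm and add, while top is just the original addmm call.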

gm = make_fx(f, pre_dispatch=True) also captures an ATen graph, but the capture happens at the “top” of the dispatcher (at the PreDispatch dispatch key, which sits near the top of the dispatcher’s key ordering). There are a few use cases for this; one is that the resulting graph can be reused with the eager autograd engine (this is not necessarily true for a vanilla make_fx graph).

Thank you for your reply. Your insights are valuable to me.