Registering custom operations and control flow for FX graphs

Hi all, I am working on Catalyst (GitHub - PennyLaneAI/catalyst), a JIT compiler for hybrid quantum programs in PennyLane. We have a pipeline that supports JAX, where we define quantum operations as JAX primitives; this allows us to lower to jaxpr, then to StableHLO, and finally to our own MLIR quantum dialect. We are now considering adding support for PyTorch, and I have the following questions:

  1. Similarly to JAX primitives, is there a way to extend PyTorch with custom operations and lower them to FX graph operations?
  2. I know that Dynamo supports control flow by splitting into multiple subgraphs. Are there native Torch control-flow operations, or is there a way to add control-flow operations to FX graphs? (A sketch of our current JAX flow follows below for context.)
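
For context, here is a minimal sketch of the kind of JAX flow we have today (simplified, with illustrative names; the real primitives are quantum operations):

import jax
from jax import core

# Define a custom primitive and give it an abstract evaluation rule so it
# can be traced without a concrete implementation.
my_op_p = core.Primitive("my_quantum_op")

@my_op_p.def_abstract_eval
def _my_op_abstract(x):
    return core.ShapedArray(x.shape, x.dtype)

def my_op(x):
    return my_op_p.bind(x)

# The primitive shows up as an equation in the jaxpr, which we then lower
# to StableHLO and our MLIR quantum dialect.
print(jax.make_jaxpr(my_op)(1.0))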

Thanks!

I think the closest terminology in Inductor is lowerings, which you can check out here: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py. I'm not sure this has a registration API quite yet, but it's a reasonable ask.

Regarding your second question around control flow, there's been some recent work adding a torch.cond operation; you can click around the files changed here to get a sense of how it works: Improve torch.cond useability: Return UserError with actionable error messages by guangy10 · Pull Request #98909 · pytorch/pytorch · GitHub
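
As a rough sketch (hedged: the API was experimental at the time and the exact import path has moved around), torch.cond traces both branches into the graph instead of splitting it at the data-dependent condition:

import torch

def true_fn(x):
    return torch.sin(x)

def false_fn(x):
    return torch.cos(x)

def f(x):
    # Both branches are captured as subgraphs of a single FX graph,
    # rather than causing a graph break on the data-dependent predicate.
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

print(torch.compile(f)(torch.randn(3)))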

That said, considering your stack already includes StableHLO, if you wait a couple of weeks we'll probably see something like torch.export(m, backend="stable_hlo") released; the public API for export() is getting merged soon: Expose torch.export() and utilities by gmagogsfm · Pull Request #106242 · pytorch/pytorch · GitHub
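
For reference, a minimal sketch using the export entry point (hedged: based on the API as it landed under torch.export; the backend= keyword above is speculative):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

# export produces an ExportedProgram holding a single FX graph that a
# downstream backend (e.g. a StableHLO converter) can consume.
ep = torch.export.export(M(), (torch.randn(3),))
print(ep.graph_module.graph)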


@marksaroufim Thanks a lot for your answer! Is there a way to add an operation to torch and build the FX graph using only the FX module, without going through Inductor?

Ideally, I would imagine something like:

# register my custom operation (see sketch below)
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def func(x, y):
    x = torch.sin(x)
    y = torch.ops.my_namespace.my_operation(y)
    return torch.add(x, y)

fx_graph = make_fx(func)(torch.tensor(3.0), torch.tensor(3.0))
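
For concreteness, I imagine the registration step could look something like this torch.library sketch (hedged: the names are placeholders, and the real op would be a quantum operation):

import torch

# Create a namespace, declare the op's schema, and attach a CPU
# implementation so make_fx can trace through a real call.
lib = torch.library.Library("my_namespace", "DEF")
lib.define("my_operation(Tensor x) -> Tensor")

def my_operation_cpu(x):
    return x * 2  # placeholder semantics

lib.impl("my_operation", my_operation_cpu, "CPU")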

Yeah, that should work just fine; you can develop a custom backend for torch.compile, see Custom Backends — PyTorch 2.0 documentation. cc @SherlockNoMad, who has been building out most of the infra here.
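
Following that doc, a minimal custom backend is just a function that receives the captured FX GraphModule and example inputs and returns a callable (a sketch, not Catalyst-specific):

import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Inspect (or rewrite / lower) the captured FX graph here, then
    # return a callable to run in place of the original Python code.
    gm.graph.print_tabular()
    return gm.forward

@torch.compile(backend=my_backend)
def f(x):
    return torch.sin(x) + 1

f(torch.randn(3))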

Thanks for your answer. Correct me if I am wrong, but the backend only comes into play after the FX graph has been built?