Fast shape inference

Hi, I have a use case where I’d like to get the output shape of torch operations.

I could use fake mode (FakeTensorMode) for this, but it’s quite slow: the fake mode overhead for 1000 torch.ones ops is about 200-250 ms, which rules it out for me. I’ve heard that meta tensors (running ops on the meta device) are also an option for shape inference, but also that their operator coverage is incomplete.
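A minimal sketch for comparing the two approaches on your own machine. It assumes a recent PyTorch where FakeTensorMode is importable from torch._subclasses.fake_tensor, which is a private module whose path may change. It propagates one matmul's output shape each way and times 1000 torch.ones calls. Timings depend on hardware and version, so none are shown here. The meta path only works for ops that have a meta kernel, which is the coverage concern mentioned above.

```python
import time

import torch
# FakeTensorMode lives in a private module; the import path may change between releases.
from torch._subclasses.fake_tensor import FakeTensorMode


def fake_output_shape() -> torch.Size:
    """Output shape of a matmul computed under fake mode (no real data is allocated)."""
    with FakeTensorMode():
        a = torch.ones(8, 16)  # factory calls inside the mode return FakeTensors
        b = torch.ones(16, 4)
        return (a @ b).shape


def meta_output_shape() -> torch.Size:
    """Same matmul on the meta device: only shape/dtype metadata is propagated."""
    a = torch.ones(8, 16, device="meta")
    b = torch.ones(16, 4, device="meta")
    return (a @ b).shape


def time_fake_mode(n: int = 1000) -> float:
    """Wall-clock seconds for n torch.ones calls under a single FakeTensorMode."""
    start = time.perf_counter()
    with FakeTensorMode():
        for _ in range(n):
            torch.ones(64, 64)
    return time.perf_counter() - start


def time_meta_device(n: int = 1000) -> float:
    """Wall-clock seconds for n torch.ones calls on the meta device."""
    start = time.perf_counter()
    for _ in range(n):
        torch.ones(64, 64, device="meta")
    return time.perf_counter() - start


if __name__ == "__main__":
    print("fake:", fake_output_shape(), f"{time_fake_mode() * 1e3:.1f} ms")
    print("meta:", meta_output_shape(), f"{time_meta_device() * 1e3:.1f} ms")
```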

Linking against libtorch is also not desirable, but I’m willing to consider it if there’s a consistent shape inference story there. Something along those lines seems to exist ([RFC] A PyTorch Tensor Shape DSL For Symbolic Shape Inference, pytorch/pytorch issue #54982 on GitHub), and I would love to be able to compute with symbolic shapes during shape inference, but it looks like that work has been dropped?

I also see a bunch of shape inference routines in ts_native_functions.cpp, but I don’t see them used anywhere, and they look like part of the LazyTensor backend. What’s the story there?
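On computing with symbolic shapes: the core idea is that a shape function can carry some dimensions as named symbols (say a batch size N) instead of integers. It only needs concrete values where a rule actually depends on them. The toy below illustrates this for broadcasting in plain Python. `Dim` and `broadcast_symbolic` are hypothetical names. This is neither the RFC's TorchScript DSL nor how PyTorch's own symbolic-shape support works internally.

```python
from typing import List, Sequence, Union

# A dimension is either a concrete size or a named symbol such as "N".
Dim = Union[int, str]


def broadcast_symbolic(a: Sequence[Dim], b: Sequence[Dim]) -> List[Dim]:
    """Right-aligned broadcasting of two shapes whose dims may be symbolic.

    Equal dims (including equal symbols) match, and a size-1 dim broadcasts to
    the other side. Any other pairing that involves a symbol cannot be decided
    without extra constraints, so this toy version raises instead of guessing.
    """
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have ndim entries.
    pa: List[Dim] = [1] * (ndim - len(a)) + list(a)
    pb: List[Dim] = [1] * (ndim - len(b)) + list(b)
    out: List[Dim] = []
    for da, db in zip(pa, pb):
        if da == db:
            out.append(da)
        elif da == 1:
            out.append(db)
        elif db == 1:
            out.append(da)
        elif isinstance(da, int) and isinstance(db, int):
            raise ValueError(f"shapes not broadcastable: {da} vs {db}")
        else:
            # e.g. "N" vs 3 is valid only if N == 3 or N == 1; unknown here.
            raise ValueError(f"cannot decide broadcast of {da!r} and {db!r}")
    return out


if __name__ == "__main__":
    print(broadcast_symbolic(["N", 3, 1], [1, 4]))  # ['N', 3, 4]
```

A fuller implementation would record a constraint (for example N == 3) at the undecidable branch instead of raising, and that bookkeeping is where most of the extra complexity of symbolic shape inference lies.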

So I’m wondering what my options are:

  • Write my own shape inference in C++ for the ops I care about. It’s a lot of up-front work, but I’d only have to do it once.
  • Bite the bullet, link against libtorch, and reuse the LazyTensor shape inference implementations somehow.
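For the first option, the per-op work is mostly bookkeeping on integer lists. Below is a minimal sketch of hand-written shape functions for three ops. It is written in Python for brevity, although the post proposes C++, where the same logic should port directly. The function names and the `SHAPE_FNS` table are hypothetical. The rules follow the documented semantics of torch.matmul, torch.cat, and torch.sum (with an explicit dim list), but this is not exhaustive: dtype, device, and many edge cases are ignored.

```python
from typing import Callable, Dict, List, Sequence

Shape = List[int]


def broadcast_shapes(a: Sequence[int], b: Sequence[int]) -> Shape:
    """Right-aligned broadcasting, used here for matmul's batch dims."""
    ndim = max(len(a), len(b))
    pa = [1] * (ndim - len(a)) + list(a)
    pb = [1] * (ndim - len(b)) + list(b)
    out: Shape = []
    for da, db in zip(pa, pb):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {da} with {db}")
        out.append(db if da == 1 else da)
    return out


def matmul_shape(a: Sequence[int], b: Sequence[int]) -> Shape:
    """Output shape of torch.matmul, including 1-D and batched cases."""
    if len(a) == 0 or len(b) == 0:
        raise ValueError("matmul does not accept 0-d tensors")
    a_was_1d, b_was_1d = len(a) == 1, len(b) == 1
    # A 1-D left operand gets a leading 1; a 1-D right operand gets a trailing 1.
    a2 = [1] + list(a) if a_was_1d else list(a)
    b2 = list(b) + [1] if b_was_1d else list(b)
    if a2[-1] != b2[-2]:
        raise ValueError(f"inner dims differ: {a2[-1]} vs {b2[-2]}")
    out = broadcast_shapes(a2[:-2], b2[:-2]) + [a2[-2], b2[-1]]
    # Drop the dims that were added for 1-D operands.
    if a_was_1d:
        out.pop(-2)
    if b_was_1d:
        out.pop(-1)
    return out


def cat_shape(shapes: Sequence[Sequence[int]], dim: int = 0) -> Shape:
    """Output shape of torch.cat: sizes add along dim, all others must match."""
    if not shapes:
        raise ValueError("cat expects a non-empty list of tensors")
    ndim = len(shapes[0])
    d = dim + ndim if dim < 0 else dim
    if not 0 <= d < ndim:
        raise IndexError(f"dim {dim} out of range for {ndim}-d tensors")
    out = list(shapes[0])
    for s in shapes[1:]:
        if len(s) != ndim:
            raise ValueError("all tensors must have the same number of dims")
        for i, (x, y) in enumerate(zip(out, s)):
            if i != d and x != y:
                raise ValueError(f"size mismatch at dim {i}: {x} vs {y}")
        out[d] += s[d]
    return out


def sum_shape(shape: Sequence[int], dim: Sequence[int], keepdim: bool = False) -> Shape:
    """Output shape of torch.sum over an explicit, non-empty list of dims."""
    ndim = len(shape)
    dims = {d + ndim if d < 0 else d for d in dim}
    if not dims:
        raise ValueError("an empty dim list (full reduction) is not handled here")
    if any(not 0 <= d < ndim for d in dims):
        raise IndexError(f"dim out of range for a {ndim}-d tensor")
    # Reduced dims become 1 with keepdim=True and disappear otherwise.
    return [1 if i in dims else s for i, s in enumerate(shape) if keepdim or i not in dims]


# Hypothetical dispatch table from op name to shape function.
SHAPE_FNS: Dict[str, Callable[..., Shape]] = {
    "matmul": matmul_shape,
    "cat": cat_shape,
    "sum": sum_shape,
}

if __name__ == "__main__":
    print(SHAPE_FNS["matmul"]([10, 4, 3], [3, 5]))  # [10, 4, 5]
```

Keeping each rule as a pure function from input shapes (plus op arguments) to an output shape makes the table easy to test against real PyTorch results for the ops you care about.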

Any thoughts? Thanks!