FX `Proxy` and shapes

Hi all,

Thanks for your work on this exciting new feature of PyTorch!

I’m interested in FX for an application that involves graph rewriting based on tensor shapes. The catch is that all of the shapes, except for a batch dimension, are known at “compile” time.

I can see one way of doing this with FX, using `Transformer` with real Tensors full of zeros and branching in `call_function`, as demonstrated in the docs.

This seems a little inefficient, however: is it possible to attach (partial) shape information to a `Proxy` and propagate it through the graph, using it to rewrite some function calls in `Transformer`?


In a similar vein, if I do use `ShapeProp` to follow dummy Tensors through the graph, is there a way to keep track of which dimensions “are” or derive from the batch dimension? (Short of picking a unique, and thus likely large, number?)
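To make the kind of bookkeeping I mean concrete, here is a minimal pure-Python sketch (not an FX API; all function names here are hypothetical): represent a shape as a tuple whose entries are either concrete ints or the symbol `"B"` for the batch dimension, and write per-op shape rules that carry the symbol through.

```python
# Hypothetical bookkeeping (not part of FX): a shape is a tuple of ints
# and/or the symbol "B" marking batch-derived dimensions.
BATCH = "B"

def transpose_shape(shape, dim0, dim1):
    """Shape rule for transpose: swap two dims; the batch symbol moves with its dim."""
    s = list(shape)
    s[dim0], s[dim1] = s[dim1], s[dim0]
    return tuple(s)

def matmul_shape(a, b):
    """Shape rule for a 2-D matmul: (n, k) @ (k, m) -> (n, m)."""
    assert len(a) == 2 and len(b) == 2
    assert a[1] == b[0], "inner dimensions must match (symbolically or concretely)"
    return (a[0], b[1])

x = (BATCH, 128)          # a batch of feature vectors
w = (128, 10)
out = matmul_shape(x, w)  # ("B", 10): the first dim is still batch-derived
```

The point is just that the symbol survives each rule, so at the end you can ask which output dimensions derive from the batch dimension without guessing from a magic number.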

Hi @Linux-cpp-lisp,

This is a great question and one that’s come up several times. The long-and-short of it is that providing this sort of symbolic shape functionality is quite difficult, since it involves writing code or symbolic formulae to calculate the shapes for every op/module you might see in the program. We currently don’t have such formulae for general usage in PyTorch, though we might sometime in the future.

If, however, your transformation can make a simplifying assumption, it may be more tractable to implement for your use case. For example, if your transform only operates on a small set of operators or modules, it may be tenable to implement the formulas for just those ops/modules.

Please let me know if that works for your situation, otherwise we can register this need with the team to better motivate the general shape analysis problem.


Hi @James_Reed ,

Thanks for the quick response!

That makes sense — for my specific instance of the problem, `ShapeProp` worked well enough, but for the general case one would definitely need symbolic shape tracing, particularly to keep track of which dimensions of intermediate results are affected by the dimensions of the inputs.

(So it's not really an issue of the efficiency of shape propagation with real Tensors — it's a more fundamental limitation, in general.)

I’m sure that general symbolic shape analysis for a library the size of PyTorch would be a very painful project of dubious general value. What might be interesting in the future is providing a common repository of symbolic shape formulas somewhere in fx that people can contribute back to over time — basically a small but growing dict of rules and an associated example of a `SymbolicShapeProp` class. The biggest challenge, as far as I can tell, would be a symbolic version of broadcasting. Once you know how two shapes broadcast, most other shape transformations that I can think of follow simple rules. (Einsum or matmul, for example.)
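For what it’s worth, here is roughly what I imagine one entry in such a rule dict looking like: a symbolic broadcast rule in plain Python, where shape entries are either ints or symbolic names (strings). This is just a sketch under my own assumptions, not anything in fx, and it deliberately gives up whenever reconciling two distinct symbols would need a real constraint solver.

```python
from itertools import zip_longest

def broadcast_shapes(a, b):
    """NumPy-style broadcasting over shapes whose entries are ints or
    symbolic names (strings). Returns the broadcast shape, or raises
    when the two entries can't be reconciled without a solver."""
    out = []
    # Align from the trailing dimensions, padding the shorter shape with 1s.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db:
            out.append(da)
        elif da == 1:
            out.append(db)
        elif db == 1:
            out.append(da)
        else:
            # Two distinct symbols, or a symbol vs. an incompatible int:
            # without a solver we can't prove they broadcast.
            raise ValueError(f"cannot broadcast {da!r} with {db!r} without a solver")
    return tuple(reversed(out))

print(broadcast_shapes(("B", 1, 4), (3, 4)))  # ("B", 3, 4)
```

As long as the symbolic batch dimension only ever meets itself or a literal 1, this handles it; the hard cases are exactly the ones that fall into the error branch.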

Anyway, that’s not a feature request, just some thoughts. Really appreciate the work you’ve all been doing on fx — right before it was released, I was writing my own very ad-hoc, limited version of it, and having this has really simplified things.


One idea I’ve been toying around with is some kind of general solver-based (e.g. SMT) shape inference. It’s unfortunately not that easy to propagate shapes in the general case — broadcasting is one significant example. However, there are lots of other annoying examples as well.

For example, `torch.cat([a, b])` results in a shape that varies with the shapes of both `a` and `b`.
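To make that concrete, a naive symbolic rule for concatenation (pure Python, hypothetical names, not a real FX or PyTorch API) can’t return a plain number along the concatenated dimension — it has to produce a symbolic sum:

```python
def cat_shape(shapes, dim=0):
    """Shape rule for concatenation along `dim`: all other dims must
    match, and the result's size at `dim` is the sum of the inputs'
    sizes there. With symbolic entries (strings), that sum becomes a
    symbolic expression string rather than an int."""
    first = shapes[0]
    for s in shapes[1:]:
        assert len(s) == len(first), "ranks must match"
        assert all(s[i] == first[i] for i in range(len(s)) if i != dim), \
            "non-concatenated dims must match"
    sizes = [s[dim] for s in shapes]
    if all(isinstance(x, int) for x in sizes):
        total = sum(sizes)
    else:
        total = "+".join(str(x) for x in sizes)  # e.g. "B+3" or "B+B"
    out = list(first)
    out[dim] = total
    return tuple(out)

print(cat_shape([(2, 4), (3, 4)]))      # (5, 4)
print(cat_shape([("B", 4), (3, 4)]))    # ("B+3", 4)
```

String-concatenating symbols obviously doesn’t scale — the expressions never simplify — which is exactly why a solver that can reason about these sums (rather than just record them) seems appealing.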

A solver-based shape inference would allow you to query whether a shape 1. could be dynamic, and 2. if it were dynamic, what the constraints on its shape would be.

