XLA AOT with PyTorch

Hi,

I saw that PyTorch supports an XLA backend.
Is it possible to compile a PyTorch model ahead of time with XLA, the way tfcompile does for TensorFlow?

Do you have a hard requirement to use XLA for code generation? If not, you might want to check out torch.compile: see the torch.compile Tutorial (PyTorch Tutorials 1.13.1+cu117 documentation).

The classic PyTorch/XLA path is a mechanism called lazy tensors; you can find more details in xla/API_GUIDE.md on the master branch of the pytorch/xla GitHub repository. The Python code still executes, but instead of running each op immediately we build a graph, then compile and execute that graph when the mark_step API is called. We have a cache, so each distinct graph is compiled only once. (We still need to trace the graph on every step, but tracing can overlap with the previous step's execution.)

The alternative is our new TorchDynamo bridge (that is, torch.compile); you can find more details in "TorchDynamo Update 10: Integrating with PyTorch/XLA for Inference and Training" on PyTorch Dev Discussions. The current status is that inference works well and we are improving the training bridge.

Thank you for your answers. I feel like I should have added some context.
I am working on integrating a new kind of ML inference hardware accelerator with existing ML libraries and a possible path is to develop a new backend for XLA.
Depending on the customer's application (I am thinking about edge devices here), we might want to be able to generate low-overhead code.
From what I have read, XLA has the benefit of being integrated with the 3 current major ML libraries and provides AOT compilation, at least for TensorFlow. If XLA AOT is only available for TensorFlow models, that is something to consider.
The other solution I am considering is TVM.

Thanks for the context. I think what you are looking for is AOTAutograd: see "AOT Autograd - How to use and optimize?" in the functorch 1.13 documentation. PyTorch/XLA has not tested the integration with AOTAutograd because we think most people will use TorchDynamo directly, which integrates the AOTAutograd stack as well. It shouldn't be too hard for PyTorch/XLA to support it, though. AOT compile-and-execute is a very natural fit for the XLA stack.

So if I understand correctly, there's an AOT feature in PyTorch that might already integrate with XLA, but no integration test has been done yet to confirm it?

I think you can experiment with AOT Autograd a bit (without XLA) and see if it gives you the feature you want. If it does, you can file a feature request under Issues in the pytorch/xla GitHub repository and we will see what it takes for us to support it.
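As a starting point for that experiment, here is a minimal sketch using aot_function from functorch.compile, with no XLA involved. The inspect_compiler callback and the function fn are made-up placeholders; a real backend would lower the captured graph in the callback instead of returning it unchanged. This assumes a PyTorch build that ships functorch.compile.

```python
import torch
from functorch.compile import aot_function

captured = []  # source code of each graph handed to the compiler callback

def inspect_compiler(fx_module, example_inputs):
    # Placeholder "compiler": a real backend would lower fx_module here.
    # We only record the traced code and return the module to run as-is.
    captured.append(fx_module.code)
    return fx_module

def fn(a, b):
    return (a * b + a).cos()

# The same callback is used for the forward and the backward graph.
aot_fn = aot_function(fn, fw_compiler=inspect_compiler, bw_compiler=inspect_compiler)

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

out = aot_fn(a, b)    # traces fn and calls the forward compiler
out.sum().backward()  # runs the traced backward graph

expected = fn(a, b)   # eager reference result for comparison
```

Printing the entries of captured shows the graphs a backend would receive, which should help you judge whether this is the AOT entry point you need.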

Thank you very much for the quick help and support :slight_smile:
