Questions about PyTorch 2.0 and MPS

Recently PyTorch announced PyTorch 2.0, which is awesome:

  1. It canonicalizes 2000+ operators down to 250+ essential prim ops and 750+ ATen ops, which reduces the amount of implementation code by about half
  2. It uses TorchDynamo, which improves graph acquisition
  3. It uses faster code generation through TorchInductor

However, as far as I understand, MPS support is at a very early stage, and according to the General MPS op coverage tracking issue there are still many ATen ops left to be implemented.

So does PyTorch 2.0 change any of the codebase or direction that MPS support has been following since PyTorch 1.12? (I mean: is TorchDynamo compatible with the current MPS graph acquisition methods?)

Also, TorchInductor compiles down to Triton, which seems to support only CUDA (and maybe ROCm in the future). With PyTorch 2.0 arriving, how is this going to affect MPS support?

Sorry if I've misunderstood something; I just want to know the general direction. Thanks!

As of today, the default backend for torch.compile on MPS-supported devices is aot_eager, so you’d do torch.compile(..., backend="aot_eager")
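A minimal sketch of what that looks like in practice, assuming PyTorch 2.0+ is installed (the function `fn` is just an illustrative example; it falls back to CPU when no MPS device is available):

```python
import torch

# Pick MPS when available, otherwise fall back to CPU so the
# sketch still runs on non-Apple hardware.
device = "mps" if torch.backends.mps.is_available() else "cpu"

def fn(x):
    return torch.sin(x) + torch.cos(x)

# aot_eager traces the function and replays the ops eagerly,
# so it doesn't need Triton (unlike the inductor backend).
compiled_fn = torch.compile(fn, backend="aot_eager")

x = torch.randn(8, device=device)
out = compiled_fn(x)

# The compiled function should match the eager result.
assert torch.allclose(out, fn(x), atol=1e-6)
```

The aot_eager backend is mainly useful for exercising the graph-capture machinery; it generally doesn't fuse or codegen kernels the way inductor does.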

Inductor support for MPS might eventually happen, but it’s primarily dependent on Triton supporting MPS. Inductor codegens Triton kernels, which can then run on GPU-like devices; right now the support in Triton is mostly NVIDIA GPU focused. You can make the ask here: GitHub - openai/triton: Development repository for the Triton language and compiler

Also, the current MPS implementation is an eager-mode implementation, so it’s not dependent on graph capture. But generally, compiler and hardware authors like graphs because there’s a well-established literature on how to optimize them.
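To illustrate what "eager mode" means here: each op is dispatched and executed as soon as it’s called, with no graph captured first. A tiny sketch (falls back to CPU when MPS isn’t available):

```python
import torch

# Eager mode: every line below runs immediately on the device;
# nothing is deferred into a captured graph.
device = "mps" if torch.backends.mps.is_available() else "cpu"

a = torch.ones(4, device=device)
b = a * 2   # dispatched and executed right away
c = b + 1   # likewise; each op is an independent kernel launch

print(c.cpu())  # tensor([3., 3., 3., 3.])
```

A graph-based compiler would instead see `a * 2` and `+ 1` together and could, for example, fuse them into a single kernel, which is why graphs matter for optimization.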
