Replicating JAX-style code in PyTorch

Are there any advantages to using JAX now that torch.func and torch.compile exist? In other words, which parts of PyTorch are intended to be used together to recreate a JAX-like experience, if that's possible? There are many compilation options available, and I'd appreciate a clear and succinct answer on the currently recommended one.
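To make the question concrete, here is a minimal sketch of the JAX-style workflow I have in mind, expressed with torch.func (the function names and shapes are just an illustrative example, not a claim about best practice):

```python
import torch
from torch.func import grad, vmap

# A scalar loss written as a pure function of its inputs -- the
# "transform a function" style that JAX encourages.
def loss(w, x):
    return (w * x).sum() ** 2

# Compose transforms the way jax.grad and jax.vmap compose:
# gradient w.r.t. the first argument, mapped over a batch of x.
dloss = grad(loss)
batched = vmap(dloss, in_dims=(None, 0))

w = torch.ones(3)
xs = torch.randn(8, 3)
per_sample_grads = batched(w, xs)
print(per_sample_grads.shape)  # torch.Size([8, 3])
```

Part of what I'm asking is whether wrapping a composition like this in torch.compile is the intended usage, and how well that combination works in practice.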

I’m aware that JAX is intended to replicate NumPy’s API as closely as possible, which is not in PyTorch’s scope. I’m seeking a head-to-head comparison, not an unrealistic attempt to stretch PyTorch to fit JAX exactly.

If PyTorch currently falls short (for instance, in how pytrees are handled with vmap and grad), please enumerate the ways so I can make an informed choice about what to use. I’m particularly interested in the debugging story for compiled code, as well as compilation speed and the ultimate performance of compiled code relative to JAX.
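For the pytree point, this is the kind of pattern I mean: in JAX, parameters are routinely a nested dict and jax.grad returns gradients with the same structure. As far as I can tell, torch.func accepts this too (the model and names below are just a toy example), and I'd like to know where the pytree support still falls short of JAX's:

```python
import torch
from torch.func import grad

# A loss over a dict-of-tensors "pytree" of parameters, JAX-style.
def loss(params, x):
    return ((x @ params["w"] + params["b"]) ** 2).sum()

params = {"w": torch.randn(3, 2), "b": torch.zeros(2)}
x = torch.randn(4, 3)

# grad differentiates w.r.t. the first argument and should return
# a dict with the same structure as params.
grads = grad(loss)(params, x)
print(sorted(grads.keys()))  # ['b', 'w']
```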

I’m especially interested in the opinion of someone who’s extensively used both the newer PyTorch features and JAX.