Is JAX really 10x faster than PyTorch?

I was reading the following post when I came across the figure below, and I was wondering whether that's true for JAX vs. PyTorch, since I haven't been following the developments in this space closely. Any thoughts?

1 Like