What would help here is an analogue of JAX's pmap. I think @ezyang is experimenting with the related vmap transform.
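For anyone unfamiliar with the two JAX transforms, here is a minimal sketch of their behavior. It assumes only the public `jax.vmap` and `jax.pmap` APIs; the function `scale_and_sum` and the array shapes are hypothetical, chosen for illustration. `vmap` vectorizes a per-example function over a leading batch axis on one device, while `pmap` maps the same function over a leading axis with one slice per device (so that axis must not exceed `jax.local_device_count()`).

```python
import jax
import jax.numpy as jnp

def scale_and_sum(x):
    # Hypothetical per-example function: operates on a single 1-D vector.
    return jnp.sum(x * 2.0)

# vmap: vectorize over the leading (batch) axis, all on one device.
batch = jnp.arange(12.0).reshape(3, 4)
per_example = jax.vmap(scale_and_sum)(batch)  # shape (3,)

# pmap: split the leading axis across devices, one row per device.
n_dev = jax.local_device_count()
shards = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)
per_device = jax.pmap(scale_and_sum)(shards)  # shape (n_dev,)
```

Both calls take the same single-example function; the difference is only where the mapped axis is executed (vectorized on one device vs. spread across devices).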