Fast way to use `map` in PyTorch?

What would help here is an analogue of JAX's `pmap`. I think @ezyang is experimenting with the related `vmap`.
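For illustration, here is a minimal sketch of the difference between a Python-level `map` and a `vmap`-style vectorized call. It assumes a recent PyTorch (2.0 or later) where `torch.vmap` is available; the reply describes `vmap` as experimental, so older versions may not have it. The per-example function `dot` and the tensor shapes are hypothetical. Note that `vmap` vectorizes over a batch dimension on one device. It is not a device-parallel `pmap`.

```python
import torch

torch.manual_seed(0)

# Hypothetical per-example function: dot product of two 1-D vectors.
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.dot(x, y)

# A batch of 64 pairs of 128-dimensional vectors (shapes chosen for illustration).
xs = torch.randn(64, 128)
ys = torch.randn(64, 128)

# Python-level map: iterating a 2-D tensor yields its rows, so `dot`
# is called once per row pair and the results are stacked afterwards.
looped = torch.stack(list(map(dot, xs, ys)))

# vmap: vectorizes `dot` over the leading dimension of both inputs,
# so no Python loop runs over the batch.
batched_dot = torch.vmap(dot)
vectorized = batched_dot(xs, ys)

print(looped.shape, vectorized.shape)
print(torch.allclose(looped, vectorized, atol=1e-4))
```

Both paths compute the same values. The gain from `vmap` is that the per-example function is written once and executed as batched tensor operations instead of one Python call per element.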