torch.vmap() and BatchNorm

I've enjoyed using the unreleased torch.vmap(), particularly because I find it natural to author models without thinking about a batch dimension. But when modules like BatchNorm1d are involved, this seems impossible, right? BatchNorm computes its normalization statistics across the batch, so a per-example function has no batch to compute them over.
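
For concreteness, here is a minimal sketch of what I mean (assuming semantics like the vmap that eventually shipped; the shapes and layer sizes are just illustrative):

```python
import torch
import torch.nn as nn

# A model authored per-example: written as if it maps a single feature
# vector, with no batch dimension in sight.
per_example = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU())

xs = torch.randn(32, 8)  # a batch of 32 examples

try:
    # vmap presents each example to the model without a batch dimension,
    # so BatchNorm1d sees a 1-D tensor and rejects the shape. Even if it
    # accepted it, BatchNorm's statistics are defined across the batch,
    # which a per-example function never sees.
    torch.vmap(per_example)(xs)
except Exception as e:
    print(f"vmap over BatchNorm fails: {e}")
```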

Are there any workarounds or plans to address this issue?
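
The only workaround I can think of (my own sketch, not an official recipe) is to swap BatchNorm for a normalization that is well defined on a single example, such as LayerNorm, so the per-example function stays batch-independent:

```python
import torch
import torch.nn as nn

# Same per-example model, but with LayerNorm, which normalizes over the
# feature dimension of each example rather than over the batch.
per_example = nn.Sequential(
    nn.Linear(8, 16),
    nn.LayerNorm(16),
    nn.ReLU(),
)

xs = torch.randn(32, 8)
ys = torch.vmap(per_example)(xs)  # lift the per-example model over dim 0
print(ys.shape)  # torch.Size([32, 16])
```

But that changes the model's semantics, so it isn't a real substitute for BatchNorm.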