Using vmap instead of explicit batch dimension: good idea?

I’m in a situation where it’s a lot easier to design a layer without an explicit batch dimension. Is it now, in general, a good idea to design an entire NN without a batch dimension and then just vmap over the batch dim?

I could see problems with BatchNorm, since it computes statistics across the batch dimension, but everybody hates BatchNorm anyway, right?
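
For concreteness, here’s a minimal sketch of the kind of thing I mean (toy shapes, and `per_sample_layer` is just a made-up example):

```python
import torch
from torch.func import vmap

# A layer written as if for a single sample: x has shape (in_features,)
def per_sample_layer(x, weight, bias):
    return torch.tanh(weight @ x + bias)

weight = torch.randn(3, 4)   # (out_features, in_features)
bias = torch.randn(3)
x = torch.randn(8, 4)        # a batch of 8 samples

# vmap maps over dim 0 of x and leaves weight/bias unbatched
y = vmap(per_sample_layer, in_dims=(0, None, None))(x, weight, bias)
print(y.shape)  # torch.Size([8, 3])
```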

Based on the docs, it seems this might indeed be a targeted use case:

vmap() can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.

However, I would also be interested in learning more about your actual use case and how you are writing your code. I know some ops might not accept batched inputs, in which case vmap sounds like the right tool, but is the main goal just to simplify your code by dropping the batch dimension?

Yes, simplification is the primary motivation for me. I’m not using any NN layers except nn.Linear (in this case), which works with arbitrary leading dimensions anyway, so I figured it would be a nice opportunity to drop the batch dimension completely.
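
Something like this minimal sketch is what I have in mind (toy shapes, assuming torch.func.vmap):

```python
import torch
from torch import nn
from torch.func import vmap

# The model is written for a single sample of shape (16,)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

x = torch.randn(64, 16)   # batch of 64 samples
y = vmap(model)(x)        # vmap supplies the batch dimension
assert y.shape == (64, 4)
```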

I’d be particularly interested in some CUDA benchmarks, since it’s not very clear to me how the various parallelization optimizations happen under the hood… but I guess I could run some simple benchmarks with e.g. nn.Linear myself.
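
Something along these lines, using torch.utils.benchmark (shapes picked arbitrarily):

```python
import torch
from torch.func import vmap
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
linear = torch.nn.Linear(512, 512).to(device)
x = torch.randn(256, 512, device=device)

# nn.Linear called directly on the batched input
t_batched = benchmark.Timer(stmt="linear(x)", globals={"linear": linear, "x": x})
# the same layer vmapped over the batch dimension
t_vmapped = benchmark.Timer(stmt="vmap(linear)(x)", globals={"vmap": vmap, "linear": linear, "x": x})

print(t_batched.timeit(100))
print(t_vmapped.timeit(100))
```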

Yes, I would also recommend benchmarking your actual use case to see whether any differences are visible and whether vmap dispatches to the optimal ops.
Also, CC @richard for visibility, in case you have already benchmarked different workloads.
