I’m in a situation where it’s a lot easier to design a layer without an explicit batch dimension. Now that vmap exists, is it generally a good idea to design an entire NN without a batch dimension and then just vmap over the batch dim?
I could see problems with BatchNorm, but everybody hates BatchNorm anyway, right?
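For concreteness, here is a minimal sketch of the pattern I have in mind (the module name and shapes are made up): the forward is written for a single unbatched example, and torch.vmap supplies the batch dimension from the outside.

```python
import torch
from torch import nn

# Hypothetical per-example model: forward only ever sees a single
# unbatched input of shape (in_features,).
class PerExampleMLP(nn.Module):
    def __init__(self, in_features: int, hidden: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (in_features,) -- no batch dimension anywhere in the model code
        return self.net(x)

model = PerExampleMLP(8, 32, 4)
batch = torch.randn(16, 8)      # (batch, in_features)

# vmap maps the forward over dim 0 of `batch`; the parameters are
# closed over and broadcast across the mapped dimension.
out = torch.vmap(model)(batch)  # shape (16, 4)
```

(BatchNorm is exactly the layer this can’t express: under vmap each call sees a single example, so there are no batch statistics to compute, and the in-place running-stat updates don’t vmap cleanly. Per-example normalizations like LayerNorm or GroupNorm are fine.)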
Based on the docs it seems this might indeed be a targeted use case:
“vmap() can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.”
However, I would also be interested in learning more about your actual use case and how you are writing your code. I know some ops might not accept a batched input, in which case vmap sounds like the right tool, but it seems your main motivation is that the code can be simplified by dropping the batch dimension?
Yes, simplification is the primary motivation for me. I’m not using any NN layers except nn.Linear (in this case), which accepts inputs of any shape as long as the last dimension matches its in_features, so I figured it would be a nice opportunity to just drop the batch dimension completely.
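To illustrate with made-up shapes: nn.Linear only looks at the last dimension, so it already runs with no batch dimension at all, one, or several leading dims:

```python
import torch
from torch import nn

lin = nn.Linear(8, 4)

print(lin(torch.randn(8)).shape)         # torch.Size([4])         -- no batch dim
print(lin(torch.randn(16, 8)).shape)     # torch.Size([16, 4])     -- the usual batched call
print(lin(torch.randn(2, 16, 8)).shape)  # torch.Size([2, 16, 4])  -- extra leading dims
```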
I was specifically interested in some CUDA benchmarks, since it’s not very clear to me how the various parallelization optimizations happen under the hood… but I guess I could do some simple benchmarks with e.g. nn.Linear myself.
Yes, I would also recommend benchmarking the actual use case to see if any differences would be visible or if vmap can properly use the optimal ops.
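A rough sketch of what that comparison could look like, assuming a plain nn.Linear workload (all sizes here are arbitrary):

```python
import torch
from torch import nn
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
linear = nn.Linear(512, 512, device=device)
x = torch.randn(256, 512, device=device)  # (batch, in_features)

# Baseline: nn.Linear handles the leading batch dimension natively.
t_native = benchmark.Timer(stmt="linear(x)", globals={"linear": linear, "x": x})

# Same computation written per-example and lifted over dim 0 with vmap.
vmapped = torch.vmap(linear)
t_vmapped = benchmark.Timer(stmt="vmapped(x)", globals={"vmapped": vmapped, "x": x})

print(t_native.timeit(100))
print(t_vmapped.timeit(100))
```

torch.utils.benchmark.Timer takes care of CUDA synchronization, so the two numbers should be directly comparable; whether vmap lowers to the same matmul kernels as the native batched call is exactly what this would show.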
Also, CC @richard for visibility in case you have already benchmarked different workloads.