Looking at the vmap documentation and playing around with the function, it looks like you can map many models to many inputs as long as num_models == num_inputs, i.e. the batch sizes are the same. My question is: what if there are N*M inputs, where N is the number of models and M is the number of inputs per model? Say, for example, I have 10 models and 20 inputs. Can I vmap the models over the inputs so that each model does a forward pass on two of the inputs?
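One way this could work is to reshape the flat `(N*M, features)` input into `(N, M, features)`. vmap then maps over a leading dimension of size N, one slice per model, and each model receives an ordinary batch of M inputs. The sketch below assumes PyTorch's `torch.func` API (`stack_module_state`, `functional_call`, `vmap`), since the question doesn't name a library. It also assumes that model `i` should get the contiguous rows `i*M` through `i*M + M - 1`. The `nn.Linear` models and the feature sizes are hypothetical placeholders.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

N, M = 10, 2                      # N models, M inputs per model (N * M = 20 inputs total)
in_features, out_features = 8, 3  # hypothetical sizes for illustration

torch.manual_seed(0)
# Hypothetical ensemble: any models sharing one architecture would do.
models = [nn.Linear(in_features, out_features) for _ in range(N)]

# Stack each parameter across models, e.g. weight -> shape (N, out_features, in_features).
params, buffers = stack_module_state(models)

# One model acts as the architecture template; functional_call swaps in
# the per-model parameter slice that vmap passes in.
template = models[0]

def fmodel(params, buffers, x):
    # Inside vmap, x has shape (M, in_features): a normal batch for a single model.
    return functional_call(template, (params, buffers), (x,))

inputs = torch.randn(N * M, in_features)     # 20 flat inputs
grouped = inputs.reshape(N, M, in_features)  # assumed grouping: model i gets rows i*M .. i*M + M - 1

out = vmap(fmodel)(params, buffers, grouped)  # shape (N, M, out_features)
flat_out = out.reshape(N * M, out_features)   # back to one row per original input
```

The key point is that vmap only needs the *mapped* dimension to match the number of models. The extra M dimension is carried through as an ordinary batch dimension that each model already handles. If the inputs are ordered differently (e.g. interleaved across models), the reshape would need to change to match.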