Question about using vmap to map many models to many inputs

Looking at the vmap documentation and experimenting with the function, it seems you can map many models to many inputs as long as `num_models == num_inputs`, i.e. the batch sizes are the same. My question is: what if there are N*M inputs, where N is the number of models? For example, say I have 10 models and 20 inputs. Can I vmap the models over the inputs so that each model does a forward pass on two of the inputs?
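One way this is commonly handled, sketched here under assumptions, is to reshape the N*M inputs to shape (N, M, ...) so that the leading axis lines up with the models' stacked parameters, then vmap over that leading axis; each model then sees its own group of M inputs. The post does not name a framework, so this sketch assumes JAX's `jax.vmap` (which takes `in_axes`) and a hypothetical ensemble of linear models whose weights are stacked along axis 0. PyTorch's `torch.func.vmap` expresses the same idea through its `in_dims` argument. All names and sizes below are illustrative.

```python
import jax
import jax.numpy as jnp

N, M = 10, 2          # number of models, inputs per model (10 models, 20 inputs)
D_IN, D_OUT = 3, 4    # hypothetical feature sizes

# Hypothetical ensemble of linear models: parameters stacked on a leading axis of size N.
key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
weights = jax.random.normal(key_w, (N, D_OUT, D_IN))
biases = jax.random.normal(key_b, (N, D_OUT))

# N*M inputs in one flat batch.
inputs = jax.random.normal(key_x, (N * M, D_IN))

def forward_one(w, b, x):
    """Forward pass of a single model on a single example x of shape (D_IN,)."""
    return w @ x + b

# Axis 0 now indexes the model, axis 1 the M inputs that model sees.
# grouped[i, j] == inputs[i * M + j], so consecutive inputs go to the same model.
grouped = inputs.reshape(N, M, D_IN)

# Inner vmap: one model over its M inputs (params shared, so in_axes=None for them).
per_model = jax.vmap(forward_one, in_axes=(None, None, 0))
# Outer vmap: pair model i's parameters with input group i.
ensemble = jax.vmap(per_model, in_axes=(0, 0, 0))

outputs = ensemble(weights, biases, grouped)  # shape (N, M, D_OUT)

# Alternative: if the model's forward already accepts a batch of shape (M, D_IN),
# a single vmap over the model axis is enough.
def forward_batch(w, b, xs):
    return xs @ w.T + b

outputs_alt = jax.vmap(forward_batch)(weights, biases, grouped)  # shape (N, M, D_OUT)

# Flatten back to one row per original input, in the original order.
flat_outputs = outputs.reshape(N * M, D_OUT)
```

The reshape assigns inputs 0 and 1 to model 0, inputs 2 and 3 to model 1, and so on; if a different assignment is wanted, the inputs would need to be permuted before reshaping. The nested-vmap form only requires a single-example forward function, while the single-vmap form relies on the model already handling a batch dimension.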