I have multiple MLPs with the same architecture but different weights.
I want to get the result from all of those MLPs on the same data (tensor of size [B, N], B - number of samples, N - number of input features).
How can I vectorize the inference on those models?
Hi @hadarshavit,
Have a look at the torch.func library: you can functionalize your model so that it takes the params as an input, and then vmap over them.
An example can be found on the forums here: Another way to implement MAML(Model-Agnostic Meta-Learning)?
Although you’ll need to replace functorch with torch.func for PyTorch 2.0 and above.
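For reference, here is a minimal sketch of what that could look like with torch.func on PyTorch 2.0+. The layer sizes, the number of models, and the helper names (`make_mlp`, `call_single`) are made up for illustration and are not from your code; it also assumes all models share exactly the same architecture, as in your question.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this example.
B, N, num_models = 8, 16, 4


def make_mlp():
    # Stand-in for your MLP; any architecture works as long as it is
    # identical across all models.
    return nn.Sequential(nn.Linear(N, 32), nn.ReLU(), nn.Linear(32, 1))


models = [make_mlp() for _ in range(num_models)]
x = torch.randn(B, N)  # the same data is fed to every model

# Stack every parameter (and buffer) along a new leading dimension,
# so each entry has shape [num_models, ...].
params, buffers = stack_module_state(models)

# A weightless copy on the "meta" device, used only for its structure.
base = copy.deepcopy(models[0]).to("meta")


def call_single(p, b, data):
    # Run the base architecture with one model's params and buffers.
    return functional_call(base, (p, b), (data,))


# in_dims=(0, 0, None): map over the stacked params and buffers,
# but pass the same data tensor to every model.
out = vmap(call_single, in_dims=(0, 0, None))(params, buffers, x)

# Sanity check against a plain Python loop over the models.
expected = torch.stack([m(x) for m in models])
```

Here `out` has shape `[num_models, B, 1]`, one slice per model. If each model should instead see its own batch, stack the inputs along a leading dimension and use `in_dims=(0, 0, 0)`.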