Hi @hadarshavit,
Have a look at the `torch.func` library: you can functionalize your model so that it takes its parameters as an input, and then `vmap` over them.
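Here is a minimal sketch of that pattern. It assumes a toy ensemble of three small MLPs; the architecture and sizes are purely illustrative. `stack_module_state` stacks each model's parameters along a new leading dimension. `functional_call` runs a single template module with whatever parameters it is given, so `vmap` can evaluate every model in one call.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap


def make_model():
    # Illustrative architecture; any nn.Module works the same way.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


num_models = 3
models = [make_model() for _ in range(num_models)]

# Stack every parameter/buffer across the models: each tensor gains a
# leading dimension of size num_models.
params, buffers = stack_module_state(models)

# A copy on the "meta" device holds no real data; it only serves as the
# architecture template for functional_call.
base_model = copy.deepcopy(models[0]).to("meta")


def fmodel(params, buffers, x):
    # Run base_model using the supplied params/buffers instead of its own.
    return functional_call(base_model, (params, buffers), (x,))


x = torch.randn(5, 4)
# Map over the leading dim of params and buffers; share the same x for all models.
out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
print(out.shape)  # torch.Size([3, 5, 2])
```

`in_dims=(0, 0, None)` is the design choice that matters here. It tells `vmap` to slice the stacked parameters and buffers per model while broadcasting the same input batch to all of them. Passing `0` for `x` as well would instead give each model its own minibatch.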
An example can be found on the forums here: Another way to implement MAML(Model-Agnostic Meta-Learning)?
That example uses `functorch`, though, so you'll need to replace it with `torch.func` for PyTorch 2.0 and above.
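The sketch below is not the code from the linked thread. It is a minimal illustration, under stated assumptions, of how a vmapped MAML step looks with `torch.func`. The old `functorch.make_functional` pattern maps onto `torch.func.functional_call`, and `functorch.grad`/`functorch.vmap` become `torch.func.grad`/`torch.func.vmap`. The sine-regression tasks, the single inner step, and `inner_lr` are hypothetical choices made only for illustration.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

model = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
# Meta-parameters as a plain dict of leaf tensors.
params = {k: v.detach().requires_grad_() for k, v in model.named_parameters()}
inner_lr = 0.1  # hypothetical inner-loop step size
meta_opt = torch.optim.Adam(list(params.values()), lr=1e-3)


def loss_fn(p, x, y):
    # Evaluate the model with parameters p instead of its registered ones.
    pred = functional_call(model, p, (x,))
    return nn.functional.mse_loss(pred, y)


def task_loss(p, x_s, y_s, x_q, y_q):
    # Inner loop: one SGD step on the support set.
    g = grad(loss_fn)(p, x_s, y_s)
    adapted = {k: p[k] - inner_lr * g[k] for k in p}
    # Outer objective: loss of the adapted parameters on the query set.
    return loss_fn(adapted, x_q, y_q)


# Hypothetical batch of 4 tasks, 10 support and 10 query points each.
num_tasks, n = 4, 10
x_s, x_q = torch.randn(num_tasks, n, 1), torch.randn(num_tasks, n, 1)
y_s, y_q = torch.sin(x_s), torch.sin(x_q)

meta_opt.zero_grad()
# Share the meta-parameters (None) and map over the task dimension (0).
per_task = vmap(task_loss, in_dims=(None, 0, 0, 0, 0))(params, x_s, y_s, x_q, y_q)
meta_loss = per_task.mean()
meta_loss.backward()  # differentiates through the inner grad (second order)
meta_opt.step()
```

Because the meta-parameters are shared with `in_dims=None` and the tasks are mapped with `0`, the per-task inner loops run as one batched computation instead of a Python loop over tasks.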