I’d like to take an input image, create dozens or hundreds of patches from it, and run each patch through potentially dozens of small MLPs (<1000 parameters each). The patch selection is done deterministically, which I’m less worried about. The MLPs, however, currently execute serially in a for-loop, which is tragically slow.
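For reference, here is a minimal sketch of the setup described above. It rests on a few assumptions. The input is a single 3x64x64 image split into non-overlapping 8x8 patches with `torch.nn.functional.unfold`, which stands in for the actual deterministic patch selection. The 32 MLPs are hypothetical, with 777 parameters each. The patches are batched along the first dimension, so the only Python loop left is the one over models.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: one 3x64x64 image, non-overlapping 8x8 patches.
image = torch.randn(1, 3, 64, 64)
patch_size = 8

# F.unfold returns shape (1, C*k*k, num_patches) = (1, 192, 64).
patches = F.unfold(image, kernel_size=patch_size, stride=patch_size)
patches = patches.squeeze(0).T  # (num_patches, C*k*k) = (64, 192)

in_dim = patches.shape[1]
num_models = 32

def make_mlp():
    # Hypothetical small MLP: (192*4 + 4) + (4*1 + 1) = 777 parameters.
    return nn.Sequential(nn.Linear(in_dim, 4), nn.ReLU(), nn.Linear(4, 1))

mlps = [make_mlp() for _ in range(num_models)]

# Serial baseline: each MLP sees all patches at once, but the models run one by one.
with torch.no_grad():
    serial_out = torch.stack([mlp(patches) for mlp in mlps])

print(serial_out.shape)  # -> torch.Size([32, 64, 1])
```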
Is there native support for parallelization of many (small) NNs?
This is provided by functorch, which is included in the latest nightly release of PyTorch or can be installed separately alongside the latest stable release.
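Use make_functional() to turn each module into a functional form and retrieve its parameters.
Then stack the parameters across the networks (giving each parameter a leading dimension indexed by model) and use vmap to dispatch the data and the stacked parameters to the functional module, so all the MLPs run in one batched call instead of a Python loop.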
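Here is a minimal sketch of the stack-then-vmap approach. It assumes PyTorch 2.0 or later, where functorch's make_functional has been superseded by `torch.func`. In that API, `stack_module_state` stacks the models' parameters, `functional_call` runs one template module with a given set of parameters, and `torch.vmap` maps that call over the model dimension. The MLP architecture, the sizes, and the random patch tensor are hypothetical placeholders.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state

torch.manual_seed(0)

# Hypothetical sizes: 32 models, 64 patches of 192 values each.
num_models, num_patches, in_dim = 32, 64, 192

def make_mlp():
    return nn.Sequential(nn.Linear(in_dim, 4), nn.ReLU(), nn.Linear(4, 1))

mlps = [make_mlp() for _ in range(num_models)]
patches = torch.randn(num_patches, in_dim)  # stand-in for real patch data

# Stack every parameter and buffer across models: each tensor gains a leading
# dimension of size num_models.
params, buffers = stack_module_state(mlps)

# Template module on the "meta" device: it supplies the architecture only, no real weights.
base = copy.deepcopy(mlps[0]).to("meta")

def run_one(p, b, x):
    # Run the template architecture with one model's parameters and buffers.
    return functional_call(base, (p, b), (x,))

# in_dims=(0, 0, None): map over the model dimension of params and buffers,
# and pass the same patches to every model.
batched = torch.vmap(run_one, in_dims=(0, 0, None))

with torch.no_grad():
    out = batched(params, buffers, patches)
    ref = torch.stack([m(patches) for m in mlps])  # serial loop, for comparison

print(out.shape)  # -> torch.Size([32, 64, 1])
print(torch.allclose(out, ref, atol=1e-5))
```

With `in_dims=(0, 0, None)`, every model receives the same patches. If each model should get its own set of patches instead, stack them into a `(num_models, num_patches, in_dim)` tensor and pass `in_dims=(0, 0, 0)`. Stacking only works when all the MLPs share one architecture.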
Have a look at the vmap documentation.