map_fn allows you to perform an operation in parallel and collect the results.
My use case is that I’d like to run several mini supervised learning problems in parallel. In each thread, I would take several gradient steps starting from the same base model and return the outputs. Then I do some computation on those outputs to update the base model, and repeat. My data is large enough that I’d like the copy of the data to the GPU to be parallelized as well, along with the forward and backward passes.
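To make the loop concrete, here is a minimal sequential sketch of what I mean (plain PyTorch, no parallelism yet). The model, the toy data, and the Reptile-style averaging in the outer update are all placeholders I made up, not my actual setup:

```python
import copy
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real base model and per-task data.
base_model = nn.Linear(4, 1)
tasks = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]

inner_steps, inner_lr = 5, 0.1
finals = []
for x, y in tasks:
    # Each worker would operate on its own copy of the base model.
    model = copy.deepcopy(base_model)
    opt = torch.optim.SGD(model.parameters(), lr=inner_lr)
    for _ in range(inner_steps):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
    finals.append([p.detach() for p in model.parameters()])

# Outer update: average the adapted weights back into the base model
# (a placeholder for whatever computation combines the outputs).
with torch.no_grad():
    for i, p in enumerate(base_model.parameters()):
        p.copy_(torch.stack([f[i] for f in finals]).mean(dim=0))
```

The question is essentially how to run the body of the outer `for` loop in parallel on the GPU, data transfer included.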
I’ve looked at DataParallel, but it seems to operate at the module level: I don’t see how to have different copies of the model take different update steps. Elsewhere, I’ve read that Python multiprocessing doesn’t always work well with CUDA.
Thanks for any advice you have!