Equivalent to TF's "map_fn" in PyTorch?

map_fn lets you apply an operation to each element of a batch (potentially in parallel) and collect the results.
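For concreteness, here is a toy `tf.map_fn` call next to the naive PyTorch equivalent (the element-wise doubling is just a placeholder for the real per-element work):

```python
import tensorflow as tf
import torch

# TF: map a function over the first axis and stack the results
xs = tf.constant([[1., 2.], [3., 4.]])
ys = tf.map_fn(lambda x: x * 2.0, xs)

# Naive PyTorch counterpart: an explicit loop plus torch.stack
xt = torch.tensor([[1., 2.], [3., 4.]])
yt = torch.stack([x * 2.0 for x in xt])
```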

My use case: I’d like to run several small supervised learning problems in parallel. In each thread, I would take several gradient steps on a copy of the same base model and return the outputs. Then I do some computation with those outputs to update the base model, and repeat. My data is large enough that I’d like the copying of data to the GPU to be parallelized along with the forward and backward passes. A rough serial version of the loop is sketched below.
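Roughly, this is the inner/outer structure I have in mind (a serial sketch; the toy model, loss, and hyper-parameters are placeholders):

```python
import copy
import torch
import torch.nn.functional as F

def inner_adapt(base_model, batch, inner_steps=5, lr=1e-2):
    """Take a few gradient steps on a private copy of the base model."""
    model = copy.deepcopy(base_model)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    x, y = batch
    for _ in range(inner_steps):
        loss = F.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model(x).detach()        # outputs consumed by the outer update

# Serial outer loop; the per-batch inner_adapt calls (data transfer
# included) are what I'd like to run in parallel:
# outputs = [inner_adapt(base_model, b) for b in task_batches]
# ...use `outputs` to update base_model, then repeat...
```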

I’ve looked at DataParallel, but it seems to operate at the module level, and I don’t see how to have different copies of the model take different update steps. I’ve also seen elsewhere that Python multiprocessing doesn’t always work well with CUDA?

Thanks for any advice you have!

I’m not sure I understood it properly, but one solution would be to adapt the current implementation of nn.DataParallel so that you have better control over when the parameters are gathered.
For example, have a look at those lines (the replicate / scatter / parallel_apply / gather calls in DataParallel.forward). I believe you can adapt them fairly easily so that you only replicate the model and gather the results from time to time.
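A minimal sketch of using those building blocks by hand (the device IDs and the gathering policy are up to you; this assumes the model lives on the first listed GPU):

```python
import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def parallel_forward(model, batch, device_ids, output_device=0):
    """One hand-rolled 'replicate + parallel forward + gather' step.

    These are the same primitives DataParallel.forward uses, but calling
    them yourself lets you decide when to re-replicate the model and
    when (or whether) to gather the outputs.
    """
    scattered = scatter((batch,), device_ids)            # per-GPU argument tuples
    replicas = replicate(model, device_ids[:len(scattered)])
    outputs = parallel_apply(replicas, scattered)        # one forward per GPU, in parallel
    return gather(outputs, output_device)                # or keep per-GPU outputs separate
```

If you only want to synchronize occasionally, you could keep the replicas around for several iterations and skip the final gather until you actually need the combined result.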

I see how I can use that to compute the forward passes for the different inputs. But can I also do backward passes?

Those functions are autograd-aware, so the backward pass is handled for you automatically: gradients from losses computed on the replicas flow back to the original model’s parameters.
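For example (a toy sketch assuming two GPUs; the model, data, and loss are placeholders):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import replicate, scatter, parallel_apply

device_ids = [0, 1]                                   # assumes two GPUs
model = nn.Linear(8, 1).cuda(device_ids[0])           # toy base model
x = torch.randn(32, 8).cuda(device_ids[0])            # toy batch
y = torch.randn(32, 1).cuda(device_ids[0])

xs = scatter(x, device_ids)                           # per-GPU input chunks
ys = scatter(y, device_ids)                           # matching target chunks
replicas = replicate(model, device_ids)               # autograd-aware copies of model
outputs = parallel_apply(replicas, xs)                # forward passes run in parallel

# Per-replica losses, summed on one device so the graph stays connected;
# backward() propagates through the replicas into model.parameters().
loss = sum(F.mse_loss(out, tgt).to(f"cuda:{device_ids[0]}")
           for out, tgt in zip(outputs, ys))
loss.backward()
```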