MAML inner loop parallelization

I want to parallelize the inner loop of MAML.
Each inner-loop iteration of MAML produces its own loss along with its own gradient graph,
and after the loop finishes, I have to aggregate the losses and then backpropagate.
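For reference, here is a minimal sketch of the sequential loop described above. It assumes second-order MAML on toy linear-regression tasks. `make_task`, `inner_loop`, and the hyperparameters are hypothetical placeholders and do not come from the original setup. `create_graph=True` keeps each inner update differentiable, so every task's loss stays connected to `meta_w` until the single outer backward pass.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy task: a small linear-regression problem split into
# a support set (used for the inner update) and a query set (meta-loss).
def make_task(n=20, dim=3):
    w_true = torch.randn(dim, 1)
    x = torch.randn(2 * n, dim)
    y = x @ w_true
    return x[:n], y[:n], x[n:], y[n:]

def inner_loop(meta_w, task, inner_lr=0.01, steps=1):
    """Adapt meta_w on the support set and return the query-set loss."""
    x_s, y_s, x_q, y_q = task
    w = meta_w
    for _ in range(steps):
        support_loss = F.mse_loss(x_s @ w, y_s)
        # create_graph=True keeps this gradient step differentiable.
        (grad,) = torch.autograd.grad(support_loss, w, create_graph=True)
        w = w - inner_lr * grad
    return F.mse_loss(x_q @ w, y_q)

meta_w = torch.zeros(3, 1, requires_grad=True)
tasks = [make_task() for _ in range(4)]

# Sequential inner loop: one loss (and one graph) per task.
losses = []
for task in tasks:
    losses.append(inner_loop(meta_w, task))

# Aggregate the per-task losses, then backpropagate once through all graphs.
meta_loss = torch.mean(torch.stack(losses))
meta_loss.backward()
print(meta_w.grad.shape)
```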

My naive idea is to replace the loop with a map.
To do this, I guess I need to aggregate the losses from multiple threads
(e.g. torch.mean(torch.stack(list_of_loss_from_multiple_threads))).
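A minimal sketch of the map idea, assuming the same hypothetical toy tasks and `inner_loop` helper as in a sequential version, with `concurrent.futures.ThreadPoolExecutor.map` standing in for the loop. Each worker thread builds its own autograd graph. The graphs are not tied to the thread that created them, so the main thread can stack the losses and backpropagate once. The sketch only checks that the threaded result matches a plain sequential loop. Whether it is actually faster is a separate question: Python threads share the GIL, so any speedup depends on how much time is spent inside PyTorch kernels that release it.

```python
import torch
import torch.nn.functional as F
from concurrent.futures import ThreadPoolExecutor

torch.manual_seed(0)

# Hypothetical toy task: support set for the inner update, query set for the meta-loss.
def make_task(n=20, dim=3):
    w_true = torch.randn(dim, 1)
    x = torch.randn(2 * n, dim)
    y = x @ w_true
    return x[:n], y[:n], x[n:], y[n:]

def inner_loop(meta_w, task, inner_lr=0.01, steps=1):
    """Adapt meta_w on the support set and return the query-set loss."""
    x_s, y_s, x_q, y_q = task
    w = meta_w
    for _ in range(steps):
        support_loss = F.mse_loss(x_s @ w, y_s)
        (grad,) = torch.autograd.grad(support_loss, w, create_graph=True)
        w = w - inner_lr * grad
    return F.mse_loss(x_q @ w, y_q)

meta_w = torch.zeros(3, 1, requires_grad=True)
tasks = [make_task() for _ in range(4)]

# The loop replaced by a map over worker threads; each returns one loss tensor
# whose graph still leads back to meta_w.
with ThreadPoolExecutor(max_workers=4) as pool:
    losses = list(pool.map(lambda task: inner_loop(meta_w, task), tasks))

# Aggregate in the main thread and backpropagate once through all task graphs.
meta_loss = torch.mean(torch.stack(losses))
(grad_threaded,) = torch.autograd.grad(meta_loss, meta_w)

# Reference: the same computation with a plain sequential loop.
seq_loss = torch.mean(torch.stack([inner_loop(meta_w, t) for t in tasks]))
(grad_sequential,) = torch.autograd.grad(seq_loss, meta_w)

print(torch.allclose(grad_threaded, grad_sequential))
```

`torch.autograd.grad` is used here instead of `.backward()` so the two meta-gradients are returned separately rather than accumulated into `meta_w.grad`.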

Is it possible to aggregate the graphs built in worker threads and then backpropagate through all of them at once?


Hi, it’s pretty tough to give any concrete advice without first knowing what exactly you’re doing and what you’ve tried. Would you mind posting a snippet of code that indicates the inner loop that you’d like to parallelize, and any instructions needed to run the code? Thanks!