Is there a way to parallelize the meta-update for MAML?

Hi, I’m trying to make my MAML code more efficient by parallelizing the meta-update. Suppose I have a meta batch size of b; then I have to use a for-loop to iterate over all b train-test (support/query) pairs for a single meta-update. In TensorFlow, you can use tf.map_fn to parallelize this loop and save a lot of time. I’m wondering how I could do the same in PyTorch. Thanks!
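
For anyone looking for a concrete starting point: one way to do this in recent PyTorch is with the torch.func transforms (available in PyTorch 2.0+), composing grad for the inner adaptation step with vmap over the meta-batch dimension. Below is a minimal sketch of this idea; the toy model, the single inner gradient step, the MSE loss, and names like inner_lr, x_spt/y_spt (support) and x_qry/y_qry (query) are illustrative assumptions, not from the original post.

```python
# Sketch: vectorizing MAML's per-task loop with torch.func (PyTorch >= 2.0).
# Assumes one inner-loop step and a toy regression model for illustration.
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))
params = dict(model.named_parameters())
inner_lr = 0.01  # hypothetical inner-loop step size

def compute_loss(p, x, y):
    # Run the model functionally with an explicit parameter dict.
    pred = functional_call(model, p, (x,))
    return nn.functional.mse_loss(pred, y)

def task_loss(p, x_spt, y_spt, x_qry, y_qry):
    # Inner loop: one gradient step on this task's support set.
    grads = grad(compute_loss)(p, x_spt, y_spt)
    adapted = {k: p[k] - inner_lr * grads[k] for k in p}
    # Outer objective: loss of the adapted params on the query set.
    return compute_loss(adapted, x_qry, y_qry)

# vmap replaces the Python for-loop over tasks: params are shared
# (in_dims=None), while the data tensors are batched along dim 0.
batched_task_loss = vmap(task_loss, in_dims=(None, 0, 0, 0, 0))

b = 8  # meta batch size
x_spt, y_spt = torch.randn(b, 5, 4), torch.randn(b, 5, 1)
x_qry, y_qry = torch.randn(b, 15, 4), torch.randn(b, 15, 1)

meta_loss = batched_task_loss(params, x_spt, y_spt, x_qry, y_qry).mean()
meta_loss.backward()  # meta-gradients flow through the inner step
```

Note that, much like tf.map_fn, this doesn't spawn parallel workers; vmap vectorizes the per-task computation into batched kernels, which is usually where the speedup over a Python for-loop comes from. On older PyTorch versions, the functorch library (the predecessor of torch.func) or the higher library are common alternatives for the same pattern.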