Parallelizing a loop over modules


I have a post-step hook attached to my optimizer that loops over each linear layer in the model and calls a method named clip_weights on it. Doing this sequentially can get quite expensive when the number of layers becomes high (in the 500s) and the layers themselves are large.
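Concretely, the hook currently does something along these lines (a simplified sketch; clip_weights here is a method I define on my own Linear subclass, not a built-in):

```python
import torch.nn as nn

def clip_hook(model):
    # Called after each optimizer step: visit every Linear layer
    # in the model sequentially and clip its weights in-place.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.clip_weights()  # hypothetical method on my custom Linear subclass
```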

What would be the best approach to parallelize this? Each clip_weights call is still compute-intensive, since it calculates the std of the weights and clamps them, so I assume it has to run on the GPU.

Here is a rough sketch of what I was thinking about:

import torch.multiprocessing as mp
from torch.multiprocessing import Pool

def worker(layer):
    # Clip this layer's weights in-place
    layer.clip_weights()

def parallel_clip_weights(layers_gen):
    with Pool(processes=mp.cpu_count()) as pool:
        # imap returns results in the order the inputs were given;
        # imap_unordered yields each result as soon as it is ready
        for _ in pool.imap_unordered(worker, layers_gen):
            pass  # worker mutates each layer in-place; nothing to collect

# Assuming model.layers() returns a generator of linear layers:
# parallel_clip_weights(model.layers())

But I am not sure what the best approach is.
Any suggestions?