How to write parallel for loop in model's forward method?

Hi, I currently have a model that takes a tensor of shape N x C x H x W and a label tensor of shape N x L. In the forward method, I want to first apply a different net to the samples of each label and then combine the results before passing them to the other layers. Currently I have code like this:

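def forward(self, data):
    # output_tensor has to be allocated before the loop; self.out_features is
    # the (assumed) output size shared by all of the separate nets
    output_tensor = data.new_zeros(data.size(0), self.out_features)
    for i in range(label_types_num):
        idx = get_label_idx(i)  # get all the indices that have this label
        group_data = data[idx]
        o = self.separate_nets[i](group_data)
        output_tensor[idx] = o
    # then output_tensor is passed to other layers...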

Is there a way to run this for loop in parallel?
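One way to avoid the loop entirely (a swapped-in technique, vectorization rather than running the loop in parallel): if each of the separate nets is a single linear layer of the same shape, all of their weights can live in one (K, out, in) tensor, each sample picks its own weights by label, and a single einsum replaces the loop. This is a minimal sketch under those assumptions; `PerLabelLinear`, `in_features`, `out_features` and the 1-D integer `labels` tensor are hypothetical (if the N x L label tensor is one-hot, `labels.argmax(dim=1)` gives such a tensor), and the image input is flattened to N x (C*H*W).

```python
import torch
import torch.nn as nn


class PerLabelLinear(nn.Module):
    """K independent linear layers stored as one (K, out, in) weight tensor.

    Hypothetical stand-in for nn.ModuleList([nn.Linear(in_f, out_f) ...])
    applied per sample according to its label, without a Python loop.
    """

    def __init__(self, label_types_num, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(label_types_num, out_features, in_features) * in_features ** -0.5
        )
        self.bias = nn.Parameter(torch.zeros(label_types_num, out_features))

    def forward(self, data, labels):
        # data: (N, C, H, W) or (N, in_features); labels: (N,) int64 in [0, K)
        data = data.flatten(1)          # (N, in_features)
        w = self.weight[labels]         # (N, out, in): each sample gathers its own weights
        b = self.bias[labels]           # (N, out)
        return torch.einsum("noi,ni->no", w, data) + b
```

Gathering the weights materializes an N x out x in tensor, so this trades memory for speed; for large layers, running the per-label loop may still be the better choice.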


I think you could use torch.distributed.launch. This snippet may help you get a better grasp of how torch.distributed.launch works.


Hi, I took a look at that snippet, and it seems to me that it works by distributing the data across different GPUs (correct me if I’m wrong). Here, however, I’m looking for a way to parallelize the for loop on a single GPU, since ‘data’ in my code is already on a single GPU when forward is called.

If you look at the snippet more carefully, you will notice that torch.distributed.launch launches multiple processes. You can use this to run a model on multiple GPUs (one per process), but you can also do much more, such as running several processes on one GPU. One limitation is that, as far as I can tell, the processes don’t share memory: each holds its own copy of the model, so launching 10 processes on the same GPU will use roughly 10x the memory.


Thank you so much! That makes sense; I’ll try to find a balance between memory and speed for my model.

Is there any concrete solution to this? An example solution would be very helpful, as the suggested post does not make this use case explicit.
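For the single-GPU case, one commonly used option is to launch each label group on its own CUDA stream, so the GPU can overlap groups that are too small to keep it busy. This is a minimal sketch, not a confirmed solution from this thread; `grouped_forward`, `nets`, `out_features` and the 1-D integer `labels` tensor are hypothetical, and each net is assumed to return a (group_size, out_features) tensor. On the CPU it falls back to the plain loop from the question.

```python
import torch
import torch.nn as nn


def grouped_forward(nets, data, labels, out_features):
    """Apply nets[i] to the samples whose label is i (hypothetical helper).

    On a GPU each group is launched on its own CUDA stream so that small
    groups can overlap; on the CPU it is a plain loop.
    """
    # nonzero() needs a device-to-host sync to learn its output size, so all
    # group indices are computed once, before any group is launched.
    groups = [(labels == i).nonzero(as_tuple=True)[0] for i in range(len(nets))]
    output = data.new_zeros(data.size(0), out_features)

    if not data.is_cuda:
        for net, idx in zip(nets, groups):
            if idx.numel() > 0:
                output[idx] = net(data[idx])
        return output

    main = torch.cuda.current_stream(data.device)
    streams = [torch.cuda.Stream(device=data.device) for _ in nets]
    for net, idx, s in zip(nets, groups, streams):
        if idx.numel() == 0:
            continue
        s.wait_stream(main)  # data, idx and output were produced on the main stream
        with torch.cuda.stream(s):
            output[idx] = net(data[idx])
    for s in streams:
        main.wait_stream(s)  # later layers on the main stream see every group's result
    return output
```

Python still issues the kernels one after another, so any speedup depends on each group's kernels leaving part of the GPU idle. Per the PyTorch CUDA semantics notes, each backward op runs on the same stream as its forward op, so training through this function should not need extra synchronization.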
