Active learning and distributed data parallel


I recently started looking into using PyTorch for an active learning project and have some questions about using distributed data parallel. As I understand it, with distributed data parallel there is one process per GPU, each running a replica of the model. Active learning approaches usually have two stages: annotation and training. In the annotation stage, I need to run the model on the unannotated samples and get annotations for the selected samples; in the training stage, the model is trained on the annotated dataset. I am trying to use DistributedDataParallel but am not sure how.

My workflow will roughly look like:

  1. Run the model on the unannotated samples (across all GPUs).
  2. Rank the samples using the output of the previous step, based on some criterion (on the main worker, rank 0).
  3. Add the new samples to the training dataset (all workers' data loaders need to be updated) and then run training.

This cycle repeats until a budget is reached (the budget needs to be synced across all processes so that they can all terminate).
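For step 3, one option is to rebuild a `Subset` and a `DistributedSampler` whenever the labeled set grows, so each worker's loader shards the new dataset consistently. This is a sketch that assumes every process holds the same list of labeled indices (for example because rank 0 broadcasts it after each round). `make_train_loader` and `labeled_indices` are hypothetical names; `num_replicas` and `rank` are passed explicitly so the snippet runs without a process group.

```python
import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Subset, TensorDataset


def make_train_loader(dataset: Dataset, labeled_indices: list[int], batch_size: int,
                      num_replicas: int, rank: int, epoch: int = 0) -> DataLoader:
    """Build this rank's loader over the currently labeled samples.

    Every rank must pass the same labeled_indices (hypothetical name), otherwise
    the shards produced by DistributedSampler will not line up across processes.
    """
    subset = Subset(dataset, labeled_indices)
    # Normally num_replicas = dist.get_world_size() and rank = dist.get_rank();
    # they are explicit here so the function also works without a process group.
    sampler = DistributedSampler(subset, num_replicas=num_replicas, rank=rank,
                                 shuffle=True, seed=0)
    sampler.set_epoch(epoch)  # same permutation on every rank for a given epoch
    return DataLoader(subset, batch_size=batch_size, sampler=sampler)


if __name__ == "__main__":
    data = TensorDataset(torch.arange(10))
    labeled = [1, 3, 5, 7]  # grows after every annotation round
    for r in range(2):  # simulate both ranks of a 2-process job
        loader = make_train_loader(data, labeled, batch_size=2, num_replicas=2, rank=r)
        print(r, [batch[0].tolist() for batch in loader])
```

If you train for several epochs per round, call `loader.sampler.set_epoch(epoch)` before each epoch: the shuffle then changes between epochs but stays identical across ranks, which is what keeps the shards disjoint.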

It is not clear to me how to get the model outputs from all processes (step 1) back to the main process, and how to run steps like the sample ranking (step 2) on the main process while the other processes wait for its result before starting the next training cycle. Also, how do I sync variables like the budget across processes? A pointer to a resource that explains this would be great. I have looked at the ImageNet training script in the GitHub repo, but it didn't help me understand this process.
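On syncing the budget: one simple pattern is to let rank 0 own the bookkeeping and broadcast a one-element flag at the top of every round, so all processes leave the loop in the same iteration. This is a sketch under that assumption; `should_continue` is a hypothetical helper, and `init_single_process_group` only creates a one-process gloo group so the snippet runs standalone (under `torchrun` you would call `dist.init_process_group` as usual).

```python
import socket

import torch
import torch.distributed as dist


def should_continue(labeled_count: int, budget: int) -> bool:
    """Rank 0 decides whether to run another round; all ranks get the same answer."""
    flag = torch.zeros(1, dtype=torch.int64)
    if dist.get_rank() == 0:
        flag[0] = 1 if labeled_count < budget else 0
    # broadcast is a collective: every rank must call it, and with the default
    # async_op=False it returns only once this rank holds rank 0's value.
    dist.broadcast(flag, src=0)
    return bool(flag.item())


def init_single_process_group() -> None:
    """Demo-only setup (hypothetical helper): a 1-process gloo group on a free port."""
    if dist.is_initialized():
        return
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=0, world_size=1)


if __name__ == "__main__":
    init_single_process_group()
    labeled, budget, per_round = 0, 100, 30
    while should_continue(labeled, budget):
        labeled += per_round  # stands in for one annotation + training round
    print(labeled)  # 120
```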


I have been looking into the distributed operations defined in torch.distributed. Are operations like gather and reduce blocking? For example, after running the model on the unlabeled data, I can call gather to collect the outputs on a single process. If I do that, are the other processes blocked as well? The main process then needs to select samples for the labeled set and broadcast them to all the processes. How do I make the other processes wait for this broadcast before beginning the training cycle?

Collective communications are designed for tightly coupled processes. To give you a better idea of how they work, I borrowed some images from MPI:
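Each collective communication primitive is blocking for the processes in the process group you created (whether that is the global WORLD group or a subgroup of it): every member of the group must make the call, and with the default async_op=False each call returns only once that process's part of the operation has finished.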
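A small illustration of what blocking means in the API: with the default `async_op=False` a collective returns only when the calling process's part is done, while `async_op=True` returns a work handle whose `wait()` does the blocking later. This sketch runs on a one-process gloo group (the setup helper is demo-only), so the sums are trivial; launched with several processes, each rank would receive the element-wise sum over all ranks. With CUDA tensors, the docs note that the CUDA work itself may still be in flight when the call returns, which this CPU sketch does not cover.

```python
import socket

import torch
import torch.distributed as dist


def init_single_process_group() -> None:
    """Demo-only setup (hypothetical helper): a 1-process gloo group on a free port."""
    if dist.is_initialized():
        return
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=0, world_size=1)


def sum_blocking(x: torch.Tensor) -> torch.Tensor:
    t = x.clone()
    # Default async_op=False: returns only after t holds the sum over all ranks.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t


def sum_async(x: torch.Tensor) -> torch.Tensor:
    t = x.clone()
    # async_op=True: returns a work handle immediately; t must not be read yet.
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # ...local computation that does not touch t could overlap here...
    work.wait()  # block until the reduction has finished
    return t


if __name__ == "__main__":
    init_single_process_group()
    x = torch.tensor([1.0, 2.0])
    print(sum_blocking(x), sum_async(x))
    dist.barrier()  # explicit sync point: waits until every rank reaches this line
```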

I would suggest you do it in the following way:

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# make sure all processes in your_group run this step
model = DDP(model, device_ids=[your_device], output_device=your_device,
            process_group=your_group)
your_input = ...  # this process's share of the unannotated samples

# output is only visible to this process;
# DDP keeps the parameters in sync across processes behind the scenes
output = model(your_input)

# make sure all processes in your_group run this step
dist.gather(..., dst=0, group=your_group)
if dist.get_rank() == 0:
    ...  # perform ranking on process 0
dist.scatter(..., src=0, group=your_group)

# all workers are now synchronized
```
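Putting the gather, rank-on-rank-0 and redistribute steps together, here is a runnable sketch under a few assumptions: each rank holds the outputs for its own shard of the unannotated pool plus the dataset indices of those samples, the ranking criterion is predictive entropy (a stand-in for whatever criterion you use), and the last step uses `dist.broadcast` instead of `dist.scatter`, because every worker needs the full list of new samples to extend its training set. Only one score per sample is gathered, not the full outputs. `select_round` is a hypothetical name and the one-process group setup is demo-only.

```python
import math
import socket

import torch
import torch.distributed as dist


def init_single_process_group() -> None:
    """Demo-only setup (hypothetical helper): a 1-process gloo group on a free port."""
    if dist.is_initialized():
        return
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=0, world_size=1)


def select_round(logits: torch.Tensor, indices: torch.Tensor,
                 num_total: int, k: int) -> torch.Tensor:
    """Score locally, gather scores on rank 0, pick top-k there, broadcast the picks.

    logits:    this rank's model outputs, shape [n_local, num_classes]
    indices:   dataset indices of those samples, shape [n_local]
    num_total: unannotated samples across all ranks (assumes k <= num_total)
    Returns the same k dataset indices on every rank.
    """
    rank, world = dist.get_rank(), dist.get_world_size()

    # Hypothetical criterion: predictive entropy (higher = more uncertain).
    probs = logits.softmax(dim=-1)
    scores = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)

    # Shards like arange(num_total)[rank::world] differ in length by at most one,
    # so pad every rank to the same size; -inf scores are never selected.
    pad = math.ceil(num_total / world) - scores.numel()
    scores = torch.cat([scores, scores.new_full((pad,), float("-inf"))])
    idx = torch.cat([indices.long(), torch.full((pad,), -1, dtype=torch.long)])

    # Only rank 0 supplies receive buffers, but every rank must call gather.
    score_bufs = [torch.empty_like(scores) for _ in range(world)] if rank == 0 else None
    idx_bufs = [torch.empty_like(idx) for _ in range(world)] if rank == 0 else None
    dist.gather(scores, gather_list=score_bufs, dst=0)
    dist.gather(idx, gather_list=idx_bufs, dst=0)

    selected = torch.empty(k, dtype=torch.long)
    if rank == 0:
        top = torch.topk(torch.cat(score_bufs), k).indices
        selected.copy_(torch.cat(idx_bufs)[top])
    # Other ranks wait here until rank 0 has finished ranking and sends the picks.
    dist.broadcast(selected, src=0)
    return selected


if __name__ == "__main__":
    init_single_process_group()
    n = 8
    shard = torch.arange(n)[dist.get_rank()::dist.get_world_size()]
    with torch.no_grad():
        logits = torch.randn(len(shard), 3)  # stands in for model(inputs for shard)
    print(select_round(logits, shard, num_total=n, k=3).tolist())
```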

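One crucial thing to notice: the receive buffers (tensors) must be able to hold the tensor that was sent, otherwise nasty errors will be thrown. In practice, the simplest approach is to make every process send a tensor of the same shape and dtype, padding where necessary.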
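If padding is inconvenient, `torch.distributed` also provides object collectives such as `gather_object` and `broadcast_object_list`, which pickle arbitrary Python objects, so each rank can send a list of a different length. They are typically slower than tensor collectives, and because they unpickle received data they should only be used between trusted processes. A sketch with a gloo group; `select_with_objects` is a hypothetical helper and the single-process setup is demo-only.

```python
import socket

import torch.distributed as dist


def init_single_process_group() -> None:
    """Demo-only setup (hypothetical helper): a 1-process gloo group on a free port."""
    if dist.is_initialized():
        return
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=0, world_size=1)


def select_with_objects(local_scores: list[tuple[int, float]], k: int) -> list[int]:
    """local_scores: (dataset_index, score) pairs for this rank; any length."""
    rank, world = dist.get_rank(), dist.get_world_size()
    # Only the destination rank provides a receive list; it must be None elsewhere.
    gathered = [None] * world if rank == 0 else None
    dist.gather_object(local_scores, gathered, dst=0)

    picks = [None]  # broadcast_object_list fills this in place on non-source ranks
    if rank == 0:
        merged = [pair for part in gathered for pair in part]
        merged.sort(key=lambda pair: pair[1], reverse=True)
        picks[0] = [idx for idx, _ in merged[:k]]
    dist.broadcast_object_list(picks, src=0)
    return picks[0]


if __name__ == "__main__":
    init_single_process_group()
    print(select_with_objects([(10, 0.1), (11, 0.9), (12, 0.5)], k=2))  # [11, 12]
```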

Thank you very much for the images! I was confused with some of the collective operations and that image really helped clear things up.