Multi-process inference

I’m looking for a way to do inference on multiple GPUs for an application where inference speed is critical. I have an upstream process that delivers images to a vision model in batches of 50-100. I have tried two ways of splitting the batches up so that each worker gets a different partition, but neither has been fully successful. I also messed around with DistributedDataLoader and other ideas, but this is as far as I have gotten so far. I would appreciate advice or suggestions.
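Conceptually, what I am after is that each of the 4 ranks processes a disjoint slice of every incoming batch, something like this (the names below mirror the pseudo-code that follows; preprocess is just a placeholder):

#Each rank takes a disjoint, strided slice of the incoming batch
#(batch, rank, world_size and model are the same names used in the pseudo-code below)
shard = batch[rank::world_size]   #rank 0 sees images 0, 4, 8, ...; rank 1 sees 1, 5, 9, ...
outputs = model([preprocess(img) for img in shard])   #preprocess is a placeholder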

METHOD 1 pseudo-code.
This method works, but it only gives about a 2X speedup with 4 GPUs and seems to spawn more processes than it should: there is a ‘print’ statement right after the imports at the top of my script, and it gets called 16 or more times (a minimal version of that print is sketched after the pseudo-code below).

world_size = comm.get_world_size()
rank = comm.get_local_rank()
model = DistributedDataParallel(model, device_ids=[rank], broadcast_buffers=False)

for ib in image_batches:
    #Upstream process delivers a batch of 50 to 100 photos
    batch = get_image_batch()
    inference_dataset = DatasetFromList(batch)

    #Create a new sampler and dataloader for this particular batch
    inferencesampler = torch.utils.data.distributed.DistributedSampler(
        inference_dataset,
        num_replicas=world_size,  #4 GPUs in my setup
        rank=rank,
        shuffle=False  #keep the original order for inference
        )

    testloader = torch.utils.data.DataLoader(
        dataset=inference_dataset,
        batch_size=bs,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=trivial_batch_collator,
        sampler=inferencesampler)

    #Run the batch
    with torch.no_grad():
        for idx, inputs in enumerate(testloader):
            outputs = model(inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
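For reference, the module-level print I mentioned above is essentially the following (exact wording is illustrative); with 4 GPUs I see this line printed 16 or more times, which is how I know extra processes are being spawned:

import os

#Placed right after the imports, at module level; every process that imports
#(or re-imports) the script runs it once, so counting these lines counts processes.
print("script imported in process", os.getpid())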

METHOD 2 pseudo-code. This method uses an IterableDataset in an attempt to keep the dataloader outside the batch loop. PyTorch prohibits resetting the attribute values of a dataloader after it has been initialized, which makes it impossible to swap out the dataset on the fly. To get around that, I added a reset function to the IterableDataset so that I can pass it a new batch and reset its internals without having to change the dataloader. The dataset works perfectly when tried in isolation (i.e., it splits the data correctly between 4 processes), but in the full program it passes duplicate data to each worker, and I can't figure out why. (I also tried the worker_init_fn approach, but it also failed; a sketch of that kind of worker_init_fn is included after the dataset class below.)

world_size = comm.get_world_size()
rank = comm.get_local_rank()
model = DistributedDataParallel(model, device_ids=[rank], broadcast_buffers=False)

#Instantiate the IterableDataset with dummy data outside the batch loop
inference_dataset = IterableBatchDataset(['a','b','c','d'])

testloader = torch.utils.data.DataLoader(
    dataset=inference_dataset,
    batch_size=bs,
    shuffle=False,            
    num_workers=4,
    worker_init_fn=worker_init_fn,
    collate_fn=trivial_batch_collator)     

for ib in image_batches:
    #Upstream process delivers a batch of 50 to 100 photos
    batch = get_image_batch()

    #Reset the internals of the IterableDataset
    inference_dataset.reset_dataset(batch)

    #Run the batch
    with torch.no_grad():
        for idx, inputs in enumerate(testloader):
            outputs = model(inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

class IterableBatchDataset(torch.utils.data.IterableDataset):
    """
    An IterableDataset for passing arbitrary-length data to a DataLoader. The
    __iter__() method is meant to ensure that data is not duplicated when there is
    more than one worker, by splitting the data into separate pieces and passing
    each worker only one piece to iterate over.
    """
    def __init__(self, batchlist):
        super().__init__()  #Initialize the parent IterableDataset
        self.data = batchlist
        self.start = 0
        self.end = len(self.data)

    def reset_dataset(self, newdataset):
        #Swap in a new batch without recreating the DataLoader
        self.data = newdataset
        self.start = 0
        self.end = len(newdataset)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  #single-process data loading, return the full iterator
            indexes = range(self.start, self.end)
        else:  #in a worker process
            print("IterableDataset worker_info.id", worker_info.id)
            #split workload across the workers
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
            indexes = range(iter_start, iter_end)
        itemlist = [self.data[t] for t in indexes]
        #print("worker_id:", worker_id, "items:", itemlist)  #very useful for debugging
        return iter(itemlist)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
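
For reference, the worker_init_fn approach I mentioned follows the workload-splitting example in the PyTorch IterableDataset documentation; a rough sketch of that pattern (not my exact code) is below. It moves the per-worker split out of __iter__ and into worker_init_fn, which rewrites the start/end attributes of each worker's copy of the dataset:

import math
import torch

def worker_init_fn(worker_id):
    #Configure this worker's copy of the dataset so that a plain iteration over
    #[dataset.start, dataset.end) yields a slice that no other worker sees.
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  #the dataset copy living in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

This also failed for me in the full program.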