How do I run Inference in parallel?

Since parallel inference does not need any communication between processes, I think you can use any of the utilities you mentioned to launch the multiprocessing. We can decompose the problem into two subproblems: 1) launching multiple processes to utilize all 4 GPUs; 2) partitioning the input data across those processes via the DataLoader (e.g., with a DistributedSampler, as in the sketch below).

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run_inference(rank, world_size):
    # create the default process group (rendezvous via env vars)
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # load the model and move it to this process's GPU
    model = YourModel()
    model.load_state_dict(torch.load(PATH, map_location="cpu"))
    model.eval()
    model.to(rank)

    # create a dataloader; the DistributedSampler gives each rank a disjoint partition
    dataset = ...
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         sampler=sampler,
                                         num_workers=4)

    # iterate over this rank's partition and run the model
    with torch.no_grad():
        for idx, data in enumerate(loader):
            ...  # move each batch to `rank` and call model(...)

    dist.destroy_process_group()

def main():
    world_size = 4
    mp.spawn(run_inference,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__ == "__main__":
    main()
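
If you need all predictions in one place afterwards, the simplest approach that keeps the processes communication-free is to have each rank write its own output file and merge the files after mp.spawn(..., join=True) returns. A minimal sketch, assuming a hypothetical outputs list collected inside the loop above and a hypothetical out_dir directory:

import glob
import os

import torch

def save_rank_outputs(outputs, rank, out_dir="preds"):
    # each rank writes its own shard; no inter-process communication needed
    os.makedirs(out_dir, exist_ok=True)
    torch.save(outputs, os.path.join(out_dir, f"preds_rank{rank}.pt"))

def merge_outputs(out_dir="preds"):
    # call this in main() after all spawned processes have joined
    shards = sorted(glob.glob(os.path.join(out_dir, "preds_rank*.pt")))
    return [item for shard in shards for item in torch.load(shard)]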