Since parallel inference does not need any communication among the different processes, I think you can use any of the utilities you mentioned to launch the multiprocessing. We can decompose your problem into two subproblems: 1) launching multiple processes to utilize all 4 GPUs; 2) partitioning the input data so that each process only runs inference on its own shard, which a DistributedSampler handles for the DataLoader. For example:
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
def run_inference(rank, world_size):
    # set rendezvous info for init_process_group (assuming a single
    # machine; any free port works)
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    # create the default process group so DistributedSampler can shard the data
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # load the model and move it to this process's GPU
    model = YourModel()
    model.load_state_dict(torch.load(PATH))
    model.eval()
    model.to(rank)
    # create a dataloader; the DistributedSampler gives each rank a disjoint
    # partition of the dataset (note: it pads with duplicate samples so that
    # every rank gets the same number of batches)
    dataset = ...
    sampler = DistributedSampler(dataset,
                                 num_replicas=world_size,
                                 rank=rank,
                                 shuffle=False)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         sampler=sampler,
                                         num_workers=4)
    # iterate over this rank's partition and run the model; no_grad()
    # skips autograd bookkeeping since we only need forward passes
    with torch.no_grad():
        for idx, data in enumerate(loader):
            ...
def main():
    world_size = 4
    mp.spawn(run_inference,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    main()
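If you would rather use a launcher utility than mp.spawn, the same run_inference function also works under torchrun, which starts the processes for you and exports the rendezvous info (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) into each process's environment. A minimal sketch, assuming the code above is saved as inference.py:

# launched with: torchrun --nproc_per_node=4 inference.py
import os

if __name__ == "__main__":
    # torchrun exports RANK and WORLD_SIZE for every process it starts,
    # so we read them from the environment instead of calling mp.spawn
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    run_inference(rank, world_size)

Because run_inference only calls os.environ.setdefault for MASTER_ADDR/MASTER_PORT, it will pick up the values torchrun provides instead of overwriting them.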