Here is an example of how to spawn parallel processes with torch.distributed.launch. Save this snippet as a Python module (say torch_dist_tuto.py), then run python -m torch.distributed.launch --nproc_per_node=4 torch_dist_tuto.py to launch 4 parallel processes. Together they sum the numbers from 0 to 9 in parallel.
import argparse
import os

import torch
import torch.distributed as dist

# torch.distributed.launch passes --local_rank to every process it spawns
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()

rank = args.local_rank
size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

# join the process group (the launcher sets MASTER_ADDR/MASTER_PORT in the environment)
dist.init_process_group("gloo", rank=rank, world_size=size)

# each process sums its own strided slice of 0..9
x = torch.arange(10)[rank::size].sum()
print("Process rank {}, partial result {}".format(rank, x))

# sum the partial results onto rank 0 (the default reduce op is SUM)
dist.reduce(x, dst=0)
if rank == 0:
    print("Final result:", x)
You can adapt this snippet to your own code: the rank variable tells each process which share of the work it is responsible for.
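For instance, here is a minimal sketch of partitioning work by rank, reusing rank and size from the snippet above (the file list and the process_file function are hypothetical placeholders for your own workload):

files = ["a.csv", "b.csv", "c.csv", "d.csv", "e.csv", "f.csv", "g.csv", "h.csv"]  # hypothetical inputs
my_files = files[rank::size]  # each process takes every size-th item, so no work is duplicated
for f in my_files:
    process_file(f)  # hypothetical function doing independent per-item work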
I wrote some MPI code years ago, and if I remember right, this approach distributes the code and then launches many processes to run it, so it seems more like distributed data parallelism (not sure)? But I have only one computer with 4 cards, and I want to parallelize the forward function within the main process while using the 4 GPUs for acceleration. I wonder if this could work too?
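For reference, here is a minimal sketch of the single-process setup I have in mind (the model and input batch are hypothetical; nn.DataParallel is used to split each forward call across the visible GPUs):

import torch
import torch.nn as nn

model = nn.Linear(128, 10)  # hypothetical model
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # one process, forward split across all visible GPUs
model = model.cuda()

batch = torch.randn(64, 128).cuda()  # hypothetical input batch
out = model(batch)  # each GPU processes a chunk of the batch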