I was running into some trouble using DistributedDataParallel, so I wrote a simple sanity-check script to measure the speedup yielded by torch's multiprocessing.
import os
import sys
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12372'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def main(rank, world_size):
    # Each spawned worker joins the process group, idles, then tears down.
    setup(rank, world_size)
    time.sleep(5)
    cleanup()


if __name__ == "__main__":
    world_size = int(sys.argv[1])
    start = time.time()
    mp.spawn(main,
             args=(world_size,),
             nprocs=world_size,
             join=True)
    stop = time.time()
    print("world size: ", world_size, "\ntime elapsed: ", stop - start)
I’ve tested with world sizes of [1, 4, 8, 16, 32] on a 32-core machine, which produced the following output:
world size: 1
time elapsed: 7.073131561279297
world size: 4
time elapsed: 8.751604080200195
world size: 8
time elapsed: 9.452412605285645
world size: 16
time elapsed: 13.709965944290161
world size: 32
time elapsed: 22.26715612411499
Am I using mp / dist incorrectly? Or is per-process overhead of this magnitude to be expected?
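To help narrow this down, here is a minimal follow-up sketch (the no-op worker and the spawn-only timing are my own idea, not something I have benchmarked yet) that times a bare mp.spawn without calling init_process_group in the worker, which should separate raw process spawn/join cost from the gloo rendezvous cost:

import sys
import time

import torch.multiprocessing as mp


def noop(rank, world_size):
    # Intentionally empty: any time measured is pure spawn/join overhead.
    pass


if __name__ == "__main__":
    world_size = int(sys.argv[1])
    start = time.time()
    # Same spawn call as above, but with no process-group setup in the worker.
    mp.spawn(noop, args=(world_size,), nprocs=world_size, join=True)
    print("world size:", world_size, "\nspawn-only time:", time.time() - start)

If the spawn-only numbers stay roughly flat as the world size grows while the full script's numbers keep climbing, that would point at the rendezvous in init_process_group rather than at process creation itself.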