Torch multiprocessing not yielding nontrivial speedup

I was running into some trouble using DistributedDataParallel, so I put together a simple sanity-check script to measure how much overhead torch's multiprocessing adds per spawned process.

import os
import sys
import torch.distributed as dist
import torch.multiprocessing as mp
import time

def setup(rank, world_size):
    # rendezvous over localhost with the gloo backend
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12372'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def main(rank, world_size):
    # each worker joins the process group, "works" for 5 seconds, and tears down
    setup(rank, world_size)
    time.sleep(5)
    cleanup()


if __name__ == "__main__":
    world_size = int(sys.argv[1])
    start = time.time()
    # spawn() passes the rank as the first argument to main
    mp.spawn(main,
             args=(world_size,),
             nprocs=world_size,
             join=True)
    stop = time.time()
    print("world size: ", world_size, "\ntime elapsed: ", stop - start)

I’ve tested with world sizes of [1, 4, 8, 16, 32] on a 32-core machine, which produced the following output:

world size: 1
time elapsed: 7.073131561279297
world size: 4
time elapsed: 8.751604080200195
world size: 8
time elapsed: 9.452412605285645
world size: 16
time elapsed: 13.709965944290161
world size: 32
time elapsed: 22.26715612411499

Am I using mp / dist wrong? Or is per-process overhead of this magnitude to be expected?
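
In case it helps narrow things down, I was also considering a stripped-down variant along these lines (just a sketch, I haven't timed it) that spawns the same number of workers with the plain multiprocessing spawn context but skips init_process_group entirely, to separate the process-creation cost from the gloo rendezvous:

import multiprocessing as mp
import sys
import time

def worker(rank):
    # same fixed "work" as the dist version, but no process group setup
    time.sleep(5)

if __name__ == "__main__":
    world_size = int(sys.argv[1])
    ctx = mp.get_context("spawn")  # torch's mp.spawn also uses the spawn start method
    start = time.time()
    procs = [ctx.Process(target=worker, args=(rank,)) for rank in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    stop = time.time()
    print("world size: ", world_size, "\ntime elapsed: ", stop - start)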