PyTorch DDP basic help needed

I am trying to learn how to use DDP so that I can move my code over from DataParallel to DDP.

I am doing this in a Jupyter notebook. For now I just want to get something basic to run: no model, only a print to show that each process is alive. I have 1 node with 4 GPUs.
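
Some background that matters in a notebook: `torch.multiprocessing.spawn` (the `mp.spawn` call below) starts its workers with Python's `spawn` start method. Each worker is a fresh interpreter that receives the target function by pickling, and a plain function is pickled by reference (module name plus qualified name), not by value. The Python `multiprocessing` documentation notes that this requires the `__main__` module to be importable by the children, so such code generally does not work from an interactive session. The stdlib-only sketch below shows that reference and what happens when a worker's `__main__` does not define the function. The stand-in `train` and the temporary module swap are illustrative only and are not part of the original code.

```python
import pickle
import sys
import types


def train(rank, param):
    # Stand-in for the worker function; its body does not matter here.
    print(f"Running basic DDP example on rank {rank}.")


# A plain function is pickled by reference: module name + qualified name, no code.
payload = pickle.dumps(train)
assert train.__module__.encode() in payload and b"train" in payload

# In a process whose module defines train, the reference resolves to the same object.
restored = pickle.loads(payload)
assert restored is train

# Simulate a freshly spawned worker whose __main__ never ran the notebook cell:
# temporarily replace the module with an empty one and unpickle again.
module_name = train.__module__
saved_module = sys.modules[module_name]
sys.modules[module_name] = types.ModuleType(module_name)
try:
    pickle.loads(payload)
    lookup_failed = False
except (AttributeError, pickle.UnpicklingError) as exc:
    # e.g. AttributeError: Can't get attribute 'train' on <module ...>
    lookup_failed = True
    print("worker-side lookup failed:", exc)
finally:
    # Put the real module back so the rest of the program is unaffected.
    sys.modules[module_name] = saved_module
```

In the real `mp.spawn` case this failure happens inside the worker process, so its traceback and any prints may end up in the terminal that started Jupyter rather than in the notebook cell.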

import os

import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, param):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, param['world_size'])
    
    cleanup()

if __name__ == '__main__':
    # param is assumed to come from an earlier notebook cell, e.g. {'num_gpus': 4}
    param['world_size'] = param['num_gpus']
    mp.spawn(train, args=(param,), nprocs=param['world_size'], join=True)

When I execute this, it just hangs and I get no output. As you can see, I try to print from my train function, but I don’t see that output either.