Cannot re-initialize CUDA in forked subprocess on network.to(device) operation

Hello,

I am trying to use the DistributedDataParallel class in my training code.
The training code is one block inside a larger block that I run to do the training and logging. Because the larger block runs twice when the multiprocessing start method is set to ‘spawn’, and rewriting my ‘main’ function would be too much work, I looked into forking the subprocesses, so that only the training block runs in parallel.

The way I initialize the subprocesses now is as follows (which I stole from here):

        import torch.multiprocessing as mp

        # Force 'fork' so the children inherit the parent's state
        mp.set_start_method('fork', force=True)

        processes = []
        for rank in range(self.world_size):
            # one training process per rank
            p = mp.Process(target=self.train, args=(rank,))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

Now the train function crashes on the “self.net.to(rank)” line in the code sample below, with the following error message:
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the ‘spawn’ start method

        self.net.to(rank)
        self.net = DistributedDataParallel(self.net, device_ids=[rank])
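
For reference, this is roughly what I understand the per-process setup should look like (a simplified sketch of my method; the nccl backend and the localhost rendezvous settings are just placeholders for my actual code):

        import os
        import torch
        import torch.distributed as dist
        from torch.nn.parallel import DistributedDataParallel

        def train(self, rank):
            # Rendezvous settings; placeholders, my real code reads these from a config
            os.environ.setdefault('MASTER_ADDR', 'localhost')
            os.environ.setdefault('MASTER_PORT', '29500')

            # Join the process group: one member per rank/GPU
            dist.init_process_group(backend='nccl', rank=rank,
                                    world_size=self.world_size)

            torch.cuda.set_device(rank)
            self.net.to(rank)  # <- this is the line that raises the error
            self.net = DistributedDataParallel(self.net, device_ids=[rank])

            # ... training loop ...

            dist.destroy_process_group()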

So far I’ve read that this can happen when CUDA has already been initialised before the multiprocessing starts, but I can’t find where that would happen in my code (I checked and removed all the .to(device) operations).
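
As a sanity check, I believe something like this right before starting the processes should show whether CUDA is already initialised in the parent (a minimal sketch, assuming torch.cuda.is_initialized() is available in the installed PyTorch version):

        import torch

        # If this prints a warning, the parent process has already created a
        # CUDA context, and any forked child will hit the
        # "Cannot re-initialize CUDA in forked subprocess" error.
        if torch.cuda.is_initialized():
            print('Warning: CUDA is already initialised in the parent process')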

Other possible causes I found could be related to the dataloader (num_workers set to 0 and pin_memory), but moving the dataloader code to after the network initialisation code still gives the same error.
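
For completeness, this is the kind of dataloader configuration I mean (a sketch; train_dataset and the batch size are placeholders for my actual setup):

        from torch.utils.data import DataLoader

        # num_workers=0 avoids worker subprocesses, and pin_memory=False
        # avoids host-memory pinning, which also initialises CUDA
        loader = DataLoader(train_dataset,
                            batch_size=32,
                            shuffle=True,
                            num_workers=0,
                            pin_memory=False)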

Is it possible to use the DistributedDataParallel class this way?
Can CUDA be cleared before initialising the processes?

Thanks in advance

Thanks for posting @Pascal_Niville, this is a known issue with the CUDA runtime; you can see a related issue here: Cannot re-initialize CUDA in forked subprocess · Issue #40403 · pytorch/pytorch · GitHub. The workaround is to use “spawn” instead of “fork”, as suggested in the error message.
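
A minimal sketch of what the spawn-based launch could look like (assuming self.train accepts the rank as its first argument, which mp.spawn passes automatically, and that the object is picklable under the spawn start method):

        import torch.multiprocessing as mp

        # mp.spawn uses the 'spawn' start method and calls self.train(rank)
        # once per process. Under 'spawn' the launching module is re-imported
        # in every child, so the entry point should be guarded with
        # `if __name__ == '__main__':` to keep the outer code from running twice.
        mp.spawn(self.train, nprocs=self.world_size, join=True)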