What is process group reinitialization and why is it bad?

Hi,

I’m currently writing a distributed training script in which a single Trainer instance has its train() method called multiple times. I wrote it this way so that I can train a model with different settings sequentially (for instance, with a different optimizer each run).

The code looks something like this:

import os
from time import sleep

import torch
import torch.nn as nn
from torch.distributed import barrier, destroy_process_group, init_process_group


class Trainer:
    def __init__(self, loss_fn, device, dataset):
        self.device = device
        self.loss_fn = loss_fn
        self.dataset = dataset
        ...

    def train(self, epochs):
        torch.cuda.set_device(self.device)
        init_process_group(backend="nccl")

        # ...training logic here...

        barrier()  # minimize the chance of the next train() call starting before other processes finish
        destroy_process_group()
        sleep(10)

if __name__ == "__main__":
    dataset = CustomDataset(10000)

    loss_fn = nn.CrossEntropyLoss()

    trainer = Trainer(
        loss_fn=loss_fn,
        device=int(os.environ.get("LOCAL_RANK", 0)),
        dataset=dataset,
    )
    trainer.train(epochs=5)
    trainer.train(epochs=5)
    trainer.train(epochs=5)
    trainer.train(epochs=5)
    trainer.train(epochs=5)
    trainer.train(epochs=5)

The problem is that I often run into the following error on the 2nd or 3rd call to trainer.train().

I’m aware that “reinitialization is not recommended” based on the following doc:

However, I’ve used some safeguards, namely the barrier and the sleep, to increase the chance that the previous default process group is destroyed before a new one is initialized. The NCCL debug logs also indicate that the process group is indeed destroyed before the DDP error above appears.
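
Just to be explicit about what "destroyed" means here, the teardown at the end of train() could also assert that the default group is really gone before the next call, along these lines (a small sketch; as far as I know is_initialized() is the standard check for this):

import torch.distributed as dist

dist.barrier()
dist.destroy_process_group()
assert not dist.is_initialized()  # default process group should be gone at this point
sleep(10)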

Question 1:
The DDP error persists no matter how long I make the sleep. Why is that the case?

Question 2:
There must be a way to reinitialize properly, since tools like torchrun probably have to do some form of reinitialization for fault-tolerant training. How do they make sure the reinitialized run doesn’t produce an error like mine?
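
To make the pattern concrete, here is the same init/destroy cycle stripped down to a standalone script (the one_round helper, the script name, and the GPU count in the launch command are placeholders for illustration, not my actual code):

# Launched with something like:
#   NCCL_DEBUG=INFO torchrun --nproc_per_node=2 repro.py
import os
from time import sleep

import torch
import torch.distributed as dist


def one_round(round_idx):
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # ...training logic would go here...
    print(f"rank {dist.get_rank()} finished round {round_idx}")

    dist.barrier()  # wait for all ranks before tearing down
    dist.destroy_process_group()
    sleep(10)  # extra safety margin before the next init


if __name__ == "__main__":
    for i in range(6):
        one_round(i)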