Unable to complete training on Multi-GPU setup

I have three RTX 6000 Ada GPUs and I am running my script with torchrun, setting --nproc_per_node=3. Training always gets stuck on the last epoch. I am running Ubuntu 24.04, Python 3.12, and torch 2.5.1+cu124.

[rank0]: Traceback (most recent call last):                                                                                                                                                
[rank0]:   File "/home/robovision/Documents/Rebanta/Code/RTSeg_v2/Pretraining/trainer/train.py", line 463, in <module>
[rank0]:     main()
[rank0]:   File "/home/robovision/Documents/Rebanta/Code/RTSeg_v2/Pretraining/trainer/train.py", line 459, in main
[rank0]:     prepare_training(model, train_directory, val_directory, batch_size=mini_batch_size, split=split, logger=logger, cache=cache, args=train_args)
[rank0]:   File "/home/robovision/Documents/Rebanta/Code/RTSeg_v2/Pretraining/trainer/train.py", line 385, in prepare_training
[rank0]:     train(model, trainloader=train_loader, valloader=validation_loader, logger=logger, args=args)
[rank0]:   File "/home/robovision/Documents/Rebanta/Code/RTSeg_v2/Pretraining/trainer/train.py", line 302, in train
[rank0]:     trained_model = trainer.train(dataloader=(trainloader, valloader),
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/robovision/Documents/Rebanta/Code/RTSeg_v2/Pretraining/trainer/../trainer/policy.py", line 664, in train
[rank0]:     validation_loss = self.validate(validation_loader, lr_scheduler, **kwargs)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/robovision/Documents/Rebanta/Code/RTSeg_v2/Pretraining/trainer/../trainer/policy.py", line 589, in validate
[rank0]:     running_loss += self.reduce_metrics(loss / len(validation_loader))
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/robovision/Documents/Rebanta/Code/RTSeg_v2/Pretraining/trainer/../trainer/policy.py", line 358, in reduce_metrics
[rank0]:     dist.all_reduce(metric, op=dist.ReduceOp.SUM)
[rank0]:   File "/home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2501, in all_reduce
[rank0]:     work = group.allreduce([tensor], opts)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: torch.distributed.DistBackendError: NCCL communicator was aborted on rank 0. 
[rank0]:[E517 12:37:45.205892340 ProcessGroupNCCL.cpp:1595] [PG ID 0 PG GUID 0(default_pg) Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=124, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=300000) ran for 300081 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x76e8de76c446 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x76e8941cc772 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x76e8941d3bb3 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x76e8941d561d in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x76e8deb9e5c0 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x9caa4 (0x76e8e009caa4 in /lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x129c3c (0x76e8e0129c3c in /lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG ID 0 PG GUID 0(default_pg) Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=124, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=300000) ran for 300081 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x76e8de76c446 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x76e8941cc772 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x76e8941d3bb3 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x76e8941d561d in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x145c0 (0x76e8deb9e5c0 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x9caa4 (0x76e8e009caa4 in /lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x129c3c (0x76e8e0129c3c in /lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1601 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x76e8de76c446 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe4271b (0x76e893e4271b in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x145c0 (0x76e8deb9e5c0 in /home/robovision/.pyvenv/deeplearning/lib/python3.12/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x9caa4 (0x76e8e009caa4 in /lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x129c3c (0x76e8e0129c3c in /lib/x86_64-linux-gnu/libc.so.6)

I tried increasing the timeout with

init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=300))
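Note that the watchdog error above reports Timeout(ms)=300000, i.e. the same 300 seconds, so a larger value would be needed to actually give the collective more headroom. A minimal sketch of what I mean (the 30-minute figure is just an arbitrary example):

    import datetime
    import torch.distributed as dist

    # Sketch: raise the NCCL watchdog timeout so a slow or briefly
    # desynchronized collective is not aborted after 5 minutes.
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=30))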

This is my main training code:

    def train(self,
          dataloader: Union[DataLoader, Tuple[DataLoader, DataLoader]],
          epochs: int = 100,
          lr_rate: float = 1e-4,
          optimizer: optim.Optimizer = optim.Adam,
          lr_scheduler: optim.lr_scheduler.LRScheduler = optim.lr_scheduler.ReduceLROnPlateau,
          warmup_scheduler: optim.lr_scheduler.LRScheduler = optim.lr_scheduler.LambdaLR,
          logger: Logger = None,
          save_dir: str = "./weights/RTSeg/pretraining/",
          **kwargs) -> nn.Module:

        train_loader, validation_loader = (dataloader if isinstance(dataloader, (list, tuple)) else (dataloader, None))

        if logger is not None and is_main_process():
            dummy_input = torch.randn(1, 3, self.model.image_size, self.model.image_size).to(
                local_rank, dtype=torch.bfloat16 if self.kwargs.get("use_bfloat16", False) else torch.float32
            )
            logger.log_graph(input=dummy_input)

        optimizer = optimizer(self.online_model.parameters(), lr_rate, eps=1e-5, weight_decay=kwargs["weight_decay"])
        lr_scheduler = lr_scheduler(optimizer, factor=0.5, mode="min", threshold=0.001, patience=3, min_lr=1e-5, eps=1e-5)

        if kwargs["warmup_epochs"] != 0:
            warmup_scheduler = warmup_scheduler(optimizer, lr_lambda=lambda epoch: self.warmup_lr_schedule(
                epoch=epoch,
                warmup_epochs=kwargs["warmup_epochs"],
                initial_lr=lr_rate / (kwargs["warmup_epochs"] / 2),
                final_lr=lr_rate
            ))

            warmup_iter = range(kwargs['warmup_epochs'])
            if is_main_process():
                warmup_iter = tqdm(warmup_iter, desc="Warmup...", unit="epoch", bar_format="{l_bar}{bar}{r_bar}",
                                dynamic_ncols=True, colour="#1f5e2a", position=0, leave=True)

            for epoch in warmup_iter:
                if hasattr(train_loader.sampler, "set_epoch"):
                    train_loader.sampler.set_epoch(epoch)
                self.warmup(train_loader, optimizer, warmup_scheduler, **kwargs)
                if is_main_process() and isinstance(warmup_iter, tqdm):
                    warmup_iter.set_postfix({"Warmup LR": optimizer.param_groups[0]['lr']})

        best_validation_loss = float("inf")
        epoch_iter = range(self.epoch, epochs)
        if is_main_process():
            epoch_iter = tqdm(epoch_iter, desc="Training...", unit="epoch", bar_format="{l_bar}{bar}{r_bar}",
                            dynamic_ncols=True, colour="#1f5e2a", position=0, leave=True)

        for epoch in epoch_iter:
            train_loader.sampler.set_epoch(epoch)
            self.epoch = epoch
            training_loss = self.train_one_epoch(train_loader, optimizer, **kwargs)

            validation_loss = float("nan")
            if validation_loader is not None:
                validation_loader.sampler.set_epoch(epoch)
                validation_loss = self.validate(validation_loader, lr_scheduler, **kwargs)

            if is_main_process() and isinstance(epoch_iter, tqdm):
                dist.barrier()
                epoch_iter.set_postfix({
                    "Training Loss": training_loss,
                    "Validation Loss": validation_loss,
                    "Learning Rate": optimizer.param_groups[0]["lr"]
                })

            if is_main_process():
                if (epoch + 1) % kwargs["snapshot_log_interval"] == 0:
                    self._save_snapshot(epoch)
                    dist.barrier()

                if logger is not None:
                    logger.log_scaler(epoch, "Learning Rate", optimizer.param_groups[0]["lr"])
                    logger.log_scaler(epoch, "Metrics/Training Loss", training_loss)
                    if validation_loader is not None:
                        logger.log_scaler(epoch, "Metrics/Validation Loss", validation_loss)

                    for (online_name, online_param), (target_name, target_param) in zip(
                        self.online_model.named_parameters(), self.target_model.named_parameters()):
                        logger.log_histogram(epoch, f"Weights/Online Model/{online_name}", online_param)
                        logger.log_histogram(epoch, f"Weights/Target Model/{target_name}", target_param)

                    if (epoch + 1) % kwargs["image_log_interval"] == 0 and validation_loader is not None:
                        test_view1 = validation_loader.dataset[0][0].unsqueeze(0).to(local_rank, dtype=torch.bfloat16 if self.kwargs.get("use_bloat16", True) else torch.float32)
                        test_view2 = validation_loader.dataset[0][1].unsqueeze(0).to(local_rank, dtype=torch.bfloat16 if self.kwargs.get("use_bloat16", True) else torch.float32)

                        with torch.no_grad():
                            self.online_model.eval()
                            view1_attn = self.online_model.module.model.get_last_attention(test_view1, processed=True)
                            view2_attn = self.online_model.module.model.get_last_attention(test_view2, processed=True)

                        logger.log_images(epoch, "Images/View 1", test_view1.squeeze(0).to("cpu"))
                        logger.log_images(epoch, "Images/View 2", test_view2.squeeze(0).to("cpu"))

                        for n, (i, j) in enumerate(zip(view1_attn, view2_attn)):
                            logger.log_salient_maps(epoch, f"Attention/View 1/Transformer Block {n}", i)
                            logger.log_salient_maps(epoch, f"Attention/View 2/Transformer Block {n}", j)

                if (epoch + 1) % 2 == 0 and validation_loader is not None:
                    if validation_loss < best_validation_loss:
                        save_model(self.model, f"RTSeg_{self.model.variant}_batch_{train_loader.batch_size}_momentum_{self.momentum}_best", save_dir)
                        best_validation_loss = validation_loss

        if is_main_process() and isinstance(epoch_iter, tqdm):
            epoch_iter.close()
            
        if is_main_process():
            dist.barrier()
            dist.destroy_process_group()
        
            print("\nTraining Complete !!!")

and this is my reduce_metrics method:

    def reduce_metrics(self, metric: torch.Tensor, mode: Literal["sum", "average"] = "average") -> torch.Tensor:
        """
        Reduce the metric across all GPUs.

        Args:
            metric (torch.Tensor): metric to reduce
            mode (Literal["sum", "average"]): keep the summed value or divide it by the world size

        Returns:
            torch.Tensor: reduced metric
        """

        if dist.is_initialized():
            metric = metric.detach().clone()
            dist.all_reduce(metric, op=dist.ReduceOp.SUM)
            if mode == "average":
                metric = metric / dist.get_world_size()
            elif mode == "sum":
                metric = metric

        return metric

Did you check if all ranks process the same number of batches? If that's not the case, the finished rank would block the communication of the others, eventually triggering the timeout.

I am actually new to DDP; could you guide me on how to check this?

Assuming you are using a map-style dataset and a DataLoader, you could check `len(loader)` on each rank, making sure the same number of batches is returned.
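For example, a minimal sketch (assuming train_loader and validation_loader are the DataLoaders built on top of a DistributedSampler) that gathers the per-rank batch counts so they can be compared on rank 0:

    import torch.distributed as dist

    # Sketch: collect len(loader) from every rank and print the list once on
    # rank 0. All entries should be identical; if they are not, the ranks will
    # issue a different number of collectives and eventually time out.
    counts = [None] * dist.get_world_size()
    dist.all_gather_object(counts, (len(train_loader), len(validation_loader)))
    if dist.get_rank() == 0:
        print("(train batches, val batches) per rank:", counts)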

Hi @ptrblck, sorry for the late reply. I checked my GPU ID, batch size, and iteration by printing

print(f"{dist.get_rank()}: {view1.shape[0]}, {epoch}/{len(dataloader)}")

and noted that, out of a total of 13346 iterations, the input's shape[0] was 32 on every iteration except the last one, where it was 15 across all the GPUs.

I am currently exploring the Join context manager as a potential solution.
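For reference, this is roughly the pattern I am looking at, as a sketch only; the names (ddp_model, compute_loss, local_rank) and the simplified loop are placeholders, my actual model being the DDP-wrapped self.online_model:

    from torch.distributed.algorithms.join import Join

    # Sketch: Join lets ranks that run out of batches early shadow the DDP
    # collectives of the ranks that still have work, so nobody blocks forever.
    with Join([ddp_model]):
        for view1, view2 in train_loader:
            optimizer.zero_grad()
            loss = compute_loss(ddp_model(view1.to(local_rank)), view2.to(local_rank))
            loss.backward()
            optimizer.step()

From what I understand, Join only covers the collectives issued by the joinables passed to it (the DDP model here), so manual dist.all_reduce calls like the one in reduce_metrics would still need to be issued by every rank.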