DDP Training Issue

Hi, I am facing an issue when training my model in a DDP setting. I have a warmup phase followed by a training phase. The warmup phase runs smoothly without any issue. In the training phase, I train and validate at each epoch. The first epoch of training and validation runs smoothly, but when the second epoch begins, it gets stuck at the last iteration of the second epoch. This is my code:

     def warmup_one_epoch(self, dataloader:DataLoader, lr_scheduler:optim.lr_scheduler.LRScheduler) -> dict:
        """Run one warmup epoch with gradient accumulation, then step ``lr_scheduler``.

        Gradients are accumulated over ``self.kwargs['num_updates']`` batches.
        The optimizer is stepped at the end of every accumulation window and
        once more on the final batch, so a trailing partial window is applied
        rather than dropped.

        Args:
            dataloader: Yields batches where ``batch[0]`` is the image tensor
                and ``batch[1]`` is the target passed to ``move_to_device``.
            lr_scheduler: Scheduler stepped exactly once after the epoch.

        Returns:
            ``{"warmup_loss": running_losses}``. ``running_losses`` holds one
            float per key of the criterion's loss dict plus ``'total_loss'``
            (``cls_loss + mask_loss``).

        Raises:
            KeyError: If ``'cls_loss'`` / ``'mask_loss'`` were never produced,
                e.g. because ``dataloader`` yielded no batches.
        """
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        device = self.__ddp_config['device']
        num_updates = self.kwargs['num_updates']
        use_scaler = self.kwargs['mixed_precision']
        num_batches = len(dataloader)

        metric = {}
        running_losses = {}

        # NOTE(review): a new GradScaler is built on every call, so the loss
        # scale found during one epoch is not carried over to the next.
        scaler = amp.grad_scaler.GradScaler(device.type)

        if is_main_process():
            progress = tqdm(dataloader, desc='Warming_up...', leave=False, position=1, colour='#0b83de', dynamic_ncols=True)
        else:
            progress = dataloader

        # NOTE(review): self.reduce_metrics is called once per batch inside
        # join(); if it issues a collective (presumably — confirm), every rank
        # must see the same number of batches or the ranks will desynchronise.
        with self.model.join():
            for step, batch in enumerate(progress):
                image = batch[0].to(device, non_blocking=True)
                target = move_to_device(batch[1], device, non_blocking=True)

                # Decide ONCE whether this batch ends an accumulation window.
                # The same flag drives both the sync context and the optimizer
                # step: a batch that triggers a step must run its forward and
                # backward outside no_sync(), otherwise the step would use
                # gradients that were never synchronised across ranks.
                is_update_step = ((step + 1) % num_updates == 0) or (step + 1 == num_batches)

                if not is_update_step and hasattr(self.model, 'no_sync'):
                    sync_context = self.model.no_sync()
                else:
                    sync_context = contextlib.nullcontext()

                with sync_context:
                    # NOTE(review): autocast is active regardless of
                    # 'mixed_precision'; that flag only selects whether the
                    # GradScaler is used. Kept as in the original — confirm
                    # this is intended.
                    with amp.autocast_mode.autocast(device_type=device.type, dtype=self.kwargs['autocast_dtype']):
                        pred = self.model(image)
                        loss, _ = self.criterion(pred, target)
                        total_loss = (loss['cls_loss'] + loss['mask_loss']) / num_updates

                    if use_scaler:
                        scaler.scale(total_loss).backward()
                    else:
                        total_loss.backward()

                if is_update_step:
                    if use_scaler:
                        scaler.unscale_(self.optimizer)
                        # nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
                        scaler.step(self.optimizer)
                        scaler.update()
                    else:
                        # nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.)
                        self.optimizer.step()

                    self.optimizer.zero_grad(set_to_none=True)

                for key, value in loss.items():
                    if key not in running_losses:
                        running_losses[key] = 0.
                    running_losses[key] += self.reduce_metrics(value.detach() / num_updates).item() / num_batches

                if is_main_process():
                    current_lr = self.optimizer.param_groups[0]['lr']
                    progress.set_postfix({
                        'lr': f"{current_lr:.3f}",
                        **{k: f"{v:.3f}" for k, v in running_losses.items()}
                    })

        if is_main_process():
            progress.close()

        running_losses['total_loss'] = running_losses['cls_loss'] + running_losses['mask_loss']

        lr_scheduler.step()
        metric["warmup_loss"] = running_losses
        torch.cuda.synchronize()

        gc.collect()
        torch.cuda.empty_cache()
        torch.clear_autocast_cache()

        return metric

    def train_one_epoch(self, dataloader:DataLoader) -> dict:
        """Train for one epoch with gradient accumulation.

        Gradients are accumulated over ``self.kwargs['num_updates']`` batches.
        The optimizer is stepped at the end of every accumulation window and
        once more on the final batch, so a trailing partial window is applied
        rather than dropped.

        Args:
            dataloader: Yields batches where ``batch[0]`` is the image tensor
                and ``batch[1]`` is the target passed to ``move_to_device``.

        Returns:
            ``{'train_loss': running_losses, 'train_metrics': running_metrics}``.
            ``running_losses`` holds one float per key of the criterion's loss
            dict plus ``'total_loss'`` (``cls_loss + mask_loss``).
            ``running_metrics`` is currently always empty because the metric
            computation below is commented out.

        Raises:
            KeyError: If ``'cls_loss'`` / ``'mask_loss'`` were never produced,
                e.g. because ``dataloader`` yielded no batches.
        """
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        device = self.__ddp_config['device']
        num_updates = self.kwargs['num_updates']
        use_scaler = self.kwargs['mixed_precision']
        num_batches = len(dataloader)

        metric = {}
        running_losses = {}
        running_metrics = {}

        # NOTE(review): a new GradScaler is built on every call, so the loss
        # scale found during one epoch is not carried over to the next.
        scaler = amp.grad_scaler.GradScaler(device.type)

        if is_main_process():
            progress = tqdm(dataloader, desc='Training...', leave=False, position=1, colour='#0b83de', dynamic_ncols=True)
        else:
            progress = dataloader

        # NOTE(review): self.reduce_metrics is called once per batch inside
        # join(); if it issues a collective (presumably — confirm), every rank
        # must see the same number of batches or the ranks will desynchronise.
        with self.model.join():
            for step, batch in enumerate(progress):
                image = batch[0].to(device, non_blocking=True)
                target = move_to_device(batch[1], device, non_blocking=True)

                # Decide ONCE whether this batch ends an accumulation window.
                # The same flag drives both the sync context and the optimizer
                # step: a batch that triggers a step must run its forward and
                # backward outside no_sync(), otherwise the step would use
                # gradients that were never synchronised across ranks.
                is_update_step = ((step + 1) % num_updates == 0) or (step + 1 == num_batches)

                if not is_update_step and hasattr(self.model, 'no_sync'):
                    sync_context = self.model.no_sync()
                else:
                    sync_context = contextlib.nullcontext()

                with sync_context:
                    # NOTE(review): autocast is active regardless of
                    # 'mixed_precision'; that flag only selects whether the
                    # GradScaler is used. Kept as in the original — confirm
                    # this is intended.
                    with amp.autocast_mode.autocast(device_type=device.type, dtype=self.kwargs['autocast_dtype']):
                        pred = self.model(image)
                        # idx is only consumed by the disabled metric code below.
                        loss, idx = self.criterion(pred, target)
                        total_loss = (loss['cls_loss'] + loss['mask_loss']) / num_updates

                    if use_scaler:
                        scaler.scale(total_loss).backward()
                    else:
                        total_loss.backward()

                if is_update_step:
                    if use_scaler:
                        scaler.unscale_(self.optimizer)
                        scaler.step(self.optimizer)
                        scaler.update()
                    else:
                        self.optimizer.step()

                    self.optimizer.zero_grad(set_to_none=True)

                for key, value in loss.items():
                    if key not in running_losses:
                        running_losses[key] = 0.
                    # detach() so the logged value never holds on to the
                    # autograd graph (matches the warmup/validation loops).
                    running_losses[key] += self.reduce_metrics(value.detach() / num_updates).item() / num_batches

                # TODO: metric computation is disabled; running_metrics stays empty.
                # cls_metric, instance_metric, semantic_metric = self.metric(pred, target, idx)
                # for (cls_key, cls_val), (sem_key, sem_val) in zip(cls_metric.items(), semantic_metric.items()):
                #     cls_metric[cls_key] = self.reduce_metrics(cls_val.detach()).item()
                #     semantic_metric[sem_key] = self.reduce_metrics(sem_val.detach()).item()

                # running_metrics['cls_metrics'] = cls_metric
                # running_metrics['instance_metrics'] = instance_metric
                # running_metrics['semantic_metrics'] = semantic_metric
                # running_metrics = add_metrics(running_metrics, running_metrics, len(dataloader))

                if is_main_process():
                    current_lr = self.optimizer.param_groups[0]['lr']
                    progress.set_postfix({
                        'lr': f"{current_lr:.3f}",
                        **{k: f"{v:.3f}" for k, v in running_losses.items()}
                    })

        if is_main_process():
            progress.close()

        running_losses['total_loss'] = running_losses['cls_loss'] + running_losses['mask_loss']

        metric['train_loss'] = running_losses
        metric['train_metrics'] = running_metrics
        torch.cuda.synchronize()

        gc.collect()
        torch.cuda.empty_cache()
        torch.clear_autocast_cache()

        return metric
    
    def validate_one_epoch(self, dataloader:DataLoader) -> dict:
        """Evaluate the model for one epoch and step ``self.lr_scheduler`` on the result.

        Args:
            dataloader: Yields batches where ``batch[0]`` is the image tensor
                and ``batch[1]`` is the target passed to ``move_to_device``.

        Returns:
            ``{'val_loss': running_losses, 'val_metrics': running_metrics}``.
            ``running_losses`` holds one float per key of the criterion's loss
            dict plus ``'total_loss'`` (``cls_loss + mask_loss``).
            ``running_metrics`` is currently always empty because the metric
            computation below is commented out.

        Raises:
            KeyError: If ``'cls_loss'`` / ``'mask_loss'`` were never produced,
                e.g. because ``dataloader`` yielded no batches.
        """
        self.model.eval()

        device = self.__ddp_config['device']
        num_updates = self.kwargs['num_updates']
        num_batches = len(dataloader)

        metric = {}
        running_losses = {}
        running_metrics = {}

        if is_main_process():
            progress = tqdm(dataloader, desc='Validating...', leave=False, position=1, colour='#0b83de', dynamic_ncols=True)
        else:
            progress = dataloader

        with torch.no_grad():
            # NOTE(review): no backward pass runs here, and
            # self.reduce_metrics is called once per batch (presumably a
            # collective — confirm). Verify that join() behaves as intended
            # in this setting if ranks can receive different batch counts.
            with self.model.join():
                for batch in progress:
                    image = batch[0].to(device, non_blocking=True)
                    target = move_to_device(batch[1], device, non_blocking=True)

                    # The original branched on 'mixed_precision' here, but both
                    # branches ran this identical autocast forward, so the
                    # conditional has been collapsed.
                    with amp.autocast_mode.autocast(device_type=device.type, dtype=self.kwargs['autocast_dtype']):
                        pred = self.model(image)
                        # idx is only consumed by the disabled metric code below.
                        loss, idx = self.criterion(pred, target)

                    for key, value in loss.items():
                        if key not in running_losses:
                            running_losses[key] = 0.
                        # NOTE(review): the loss is divided by 'num_updates'
                        # (the gradient-accumulation factor) even though
                        # nothing is accumulated in validation; kept so the
                        # reported scale matches the training loops.
                        running_losses[key] += self.reduce_metrics(value.detach() / num_updates).item() / num_batches

                    # TODO: metric computation is disabled; running_metrics stays empty.
                    # cls_metric, instance_metric, semantic_metric = self.metric(pred, target, idx)
                    # for (cls_key, cls_val), (sem_key, sem_val) in zip(cls_metric.items(), semantic_metric.items()):
                    #     cls_metric[cls_key] = self.reduce_metrics(cls_val.detach()).item()
                    #     semantic_metric[sem_key] = self.reduce_metrics(sem_val.detach()).item()

                    # running_metrics['cls_metrics'] = cls_metric
                    # running_metrics['instance_metrics'] = instance_metric
                    # running_metrics['semantic_metrics'] = semantic_metric
                    # running_metrics = add_metrics(running_metrics, running_metrics, len(dataloader))

                    if is_main_process():
                        current_lr = self.lr_scheduler.get_last_lr()[0]
                        progress.set_postfix({
                            'lr': f"{current_lr:.3f}",
                            **{k: f"{v:.3f}" for k, v in running_losses.items()}
                        })

        if is_main_process():
            progress.close()

        running_losses['total_loss'] = running_losses['cls_loss'] + running_losses['mask_loss']

        # The scheduler is stepped with the epoch's validation loss.
        self.lr_scheduler.step(running_losses['total_loss'])
        metric['val_loss'] = running_losses
        metric['val_metrics'] = running_metrics
        torch.cuda.synchronize()

        gc.collect()
        torch.clear_autocast_cache()
        torch.cuda.empty_cache()

        return metric
    
    def __call__(self, dataloader:Union[DataLoader, Tuple[DataLoader, DataLoader]]) -> None:
        """Run the optional warmup phase followed by the train/validate loop.

        Args:
            dataloader: Either a single training ``DataLoader``, a 1-element
                sequence ``(train_loader,)``, or a 2-element sequence
                ``(train_loader, val_loader)``.

        Raises:
            RuntimeError: If ``dataloader`` is a sequence whose length is not
                1 or 2.
        """
        train_loader = val_loader = None
        if isinstance(dataloader, DataLoader):
            # len() of a DataLoader is its number of batches, not a number of
            # loaders, so a bare loader has to be recognised by type.
            train_loader = dataloader
        elif len(dataloader) == 2:
            train_loader, val_loader = dataloader
        elif len(dataloader) == 1:
            (train_loader,) = dataloader
        else:
            raise RuntimeError(f"length of dataloader is {len(dataloader)}, supports only 1 or 2")

        if self.logger is not None and is_main_process():
            self.logger.log_graph(torch.randn(1, 3, self.image_size, self.image_size,
                                         device=next(self.model.parameters()).device))

        metric = {}
        progress = None
        best_val_loss = float("inf")

        # Warmup runs only when requested and no snapshot file exists yet.
        if self.kwargs['warmup_epochs'] > 0 and not os.path.isfile(f"{self.kwargs['snapshot_dir']}/{self.model.module.variant}.pt"):
            warmup_scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer,
                lr_lambda=lambda epoch: self.warmup_lr_schedule(
                    epoch=epoch,
                    warmup_epochs=self.kwargs['warmup_epochs'],
                    initial_lr=self.kwargs['lr']/(dist.get_world_size()*self.kwargs['warmup_epochs']/2),
                    final_lr=self.kwargs['lr']/dist.get_world_size(),
                ),
            )

            if is_main_process():
                progress = tqdm(range(self.kwargs['warmup_epochs']), desc='Warmup Phase...', leave=True, position=0, colour='#1066a3', dynamic_ncols=True)
            else:
                progress = range(self.kwargs['warmup_epochs'])

            for epoch in progress:
                train_loader.sampler.set_epoch(epoch)
                # warmup_one_epoch takes (dataloader, lr_scheduler); it uses
                # self.optimizer internally, so the optimizer is not passed.
                warmup_metrics = self.warmup_one_epoch(train_loader, warmup_scheduler)

                if is_main_process():
                    current_lr = warmup_scheduler.get_last_lr()[0]
                    progress.set_postfix({
                        'lr': f"{current_lr:.3f}",
                        **{k: f"{v:.3f}" for k, v in warmup_metrics['warmup_loss'].items()}
                    })
                dist.barrier()

            if is_main_process():
                self._save_snapshot(epoch=0)
                progress.close()
            dist.barrier()

        if is_main_process():
            print("\n")
            progress = tqdm(range(self.epoch, self.kwargs['epochs']), desc='Train Validation Phase...', leave=True, position=0, colour='#1066a3')
        else:
            progress = range(self.epoch, self.kwargs['epochs'])

        for epoch in progress:
            train_loader.sampler.set_epoch(epoch)
            train_metrics = self.train_one_epoch(train_loader)

            # Stays None when no validation loader was supplied, so the code
            # below never touches an undefined name.
            val_metrics = None
            if val_loader is not None:
                val_loader.sampler.set_epoch(epoch)
                val_metrics = self.validate_one_epoch(val_loader)

            metric['learning_rate'] = self.optimizer.param_groups[0]['lr'] * dist.get_world_size()
            metric['train_metric'] = train_metrics
            metric['val_metric'] = val_metrics

            # Everything in this branch runs on the main process only.
            if is_main_process():
                current_lr = self.optimizer.param_groups[0]['lr']
                postfix = {
                    'lr': f"{current_lr:.3f}",
                    'train_loss': f"{train_metrics['train_loss']['total_loss']:.4f}",
                }
                if val_metrics is not None:
                    postfix['val_loss'] = f"{val_metrics['val_loss']['total_loss']:.4f}"
                progress.set_postfix(postfix)

                if (epoch+1) % self.kwargs['snapshot_log_interval'] == 0:
                    self._save_snapshot(epoch+1)

                # A "best" checkpoint is only defined when validation ran.
                if val_metrics is not None and val_metrics['val_loss']['total_loss'] < best_val_loss:
                    best_val_loss = val_metrics['val_loss']['total_loss']
                    self._check_dir(self.kwargs['checkpoint_dir'])
                    save_model(self.model.module, f"RTSeg_{self.model.module.variant}_batch_{self.kwargs['batch_size']}_best", self.kwargs['checkpoint_dir'])

                self._check_dir(self.kwargs['log_dir'])
                self.log_metrics(metric, epoch+1)

            self.epoch = epoch + 1
            # Other ranks wait here while the main process saves and logs.
            dist.barrier()

        if is_main_process():
            progress.close()

I am using torch==2.5.1 with 3 Ada 6000 GPUs.