Training loop freezes after a certain epoch

I am facing an issue where my training loop suddenly freezes after a certain epoch when my dataset is large. I am training a model using the BYOL strategy. In a test run with a smaller dataset (6 datapoints), the training loop freezes after the 6th epoch and continues after some time, but when I use the larger dataset, the training loop freezes after the 6th epoch and the system memory consumption keeps increasing until the process is killed.

def train_one_epoch(self, train_loader:DataLoader, optimizer:optim.Optimizer, **kwargs) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Train a single epoch on the train_loader.

        Args:
            train_loader (DataLoader): training dataloader
            optimizer (optim.Optimizer): optimizer
        
        Returns:
            torch.Tensor, Tuple[torch.Tensor, torch.Tensor]: training loss, (augmented view 1, augmented view 2) for tensorboard logging
        """

        running_training_loss = 0
        scaler = amp.grad_scaler.GradScaler(device.type)
        self.online_model.train()

        pbar = tqdm(train_loader, desc='Training...', total=len(train_loader), unit='batch', leave=False, position=1)

        for _, (view1, view2) in enumerate(pbar):
            view1, view2 = view1.to(device), view2.to(device)
            optimizer.zero_grad(set_to_none=True)

            if kwargs['mixed_precision']:
                with amp.autocast_mode.autocast(device.type):
                    o1 = self.online_model(view1)
                    o2 = self.online_model(view2)

                    with torch.no_grad():
                        t1 = self.target_model(view1)
                        t2 = self.target_model(view2)

                    loss = (self.calculate_loss(o1, t2.detach()) + self.calculate_loss(o2, t1.detach())).mean()

                # run the scaled backward pass and optimizer step outside the autocast context
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            else:
                o1 = self.online_model(view1)
                o2 = self.online_model(view2)

                with torch.no_grad():
                    t1 = self.target_model(view1)
                    t2 = self.target_model(view2)
                
                loss = (self.calculate_loss(o1, t2.detach()) + self.calculate_loss(o2, t1.detach())).mean()
                loss.backward()
                optimizer.step()

            running_training_loss += loss.item() / len(train_loader)        
            pbar.set_postfix({'Training loss': running_training_loss})
            pbar.update()

            torch.cuda.empty_cache()

        return running_training_loss, (view1[0], view2[0])

    def validate(self, validation_loader:DataLoader, lr_scheduler:optim.lr_scheduler.LRScheduler) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run a single validation loop for the model

        Args:
            validation_loader (DataLoader): validation loader
            lr_scheduler (LRScheduler): learning rate scheduler

        Returns:
            torch.Tensor, Tuple[torch.Tensor, torch.Tensor]: validation loss, (augmented view 1, augmented view 2) for tensorboard logging
        """

        running_validation_loss = 0
        self.online_model.eval()
        self.target_model.eval()

        pbar = tqdm(validation_loader, desc='Validating...', total=len(validation_loader), unit='batch', leave=False, position=1)

        with torch.no_grad():
            for _, (view1, view2) in enumerate(pbar):
                view1, view2 = view1.to(device), view2.to(device)
                
                o1 = self.online_model(view1)
                o2 = self.online_model(view2)

                t1 = self.target_model(view1)
                t2 = self.target_model(view2)

                loss = (self.calculate_loss(o1, t2) + self.calculate_loss(o2, t1)).mean()
                lr_scheduler.step(loss)

                running_validation_loss += loss.item() / len(validation_loader)
                pbar.set_postfix({'Validation loss': running_validation_loss})
                pbar.update()

                torch.cuda.empty_cache()

        return running_validation_loss, (view1[0], view2[0])
    
    def train(self,
              dataloader:Union[DataLoader, Tuple[DataLoader, DataLoader]],
              epochs:int=100,
              lr_rate:float=1e-4,
              optimizer:optim.Optimizer=optim.Adam,
              lr_scheduler:optim.lr_scheduler.LRScheduler=optim.lr_scheduler.ReduceLROnPlateau,
              logger:Logger=None,
              save_dir:str="weights/SegFormer/encoder/",
              **kwargs) -> nn.Module:
        r"""
        Implements the BYOL training strategy

        Args:
            dataloader (Union[DataLoader, Tuple[DataLoader, DataLoader]]): training and validation loader (if any)
            epochs (int): number of epochs
            lr_rate (float): learning rate
            optimizer (optim.Optimizer): training optimizer
            lr_scheduler (optim.lr_scheduler.LRScheduler): learning rate scheduler
            logger (Logger): tensorboard logger
            save_dir (str): model save directory

        Returns:
            nn.Module: trained encoder
        """
        
        train_loader, validation_loader = None, None
        if isinstance(dataloader, list) or isinstance(dataloader, tuple):
            train_loader, validation_loader = dataloader
        else:
            train_loader = dataloader

        if logger is not None:
            logger.log_graph(input=torch.randn(1, 3, self.target_model.image_size, self.target_model.image_size).to(device))

        optimizer = optimizer(self.online_model.parameters(), lr_rate, eps=1e-4, weight_decay = kwargs['weight_decay'])
        lr_scheduler = lr_scheduler(optimizer, mode="min", threshold=0.01, patience=3, min_lr=1e-5, eps=1e-3)

        profiler = logger.gpu_profiler()
        profiler.start()

        pbar = tqdm(range(epochs), desc="Training...", unit="epoch", bar_format="{l_bar}{bar}{r_bar}", dynamic_ncols=True, colour="#1f5e2a", position=0, leave=True)
        best_validation_loss = 2.
        for epoch in pbar:
            training_loss, (view1, view2) = self.train_one_epoch(train_loader, optimizer, **kwargs)
            profiler.step()
            pbar.set_postfix({"Training Loss": training_loss})

            if validation_loader is not None:
                validation_loss, (view1, view2) = self.validate(validation_loader, lr_scheduler)
                
            pbar.set_postfix({"Training Loss": training_loss,
                              "Validation Loss": validation_loss,
                              "Learning Rate": lr_scheduler.get_last_lr()[-1]})

            if logger is not None:
                logger.log_scaler(epoch, "Learning Rate", lr_scheduler.get_last_lr()[-1])
                logger.log_scaler(epoch, "Metrics/Training Loss", training_loss)
                logger.log_scaler(epoch, "Metrics/Validation Loss", validation_loss)
                logger.log_histogram(epoch, "Weights Distribution/Online Model Embedding Layer", [p for p in self.online_model.named_parameters()][-4][1])
                logger.log_histogram(epoch, "Weights Distribution/Target Model Embedding Layer", [p for p in self.target_model.named_parameters()][-2][1])
                
                if epoch % 10 == 0:
                    logger.log_images(epoch, "Images/View 1", view1)
                    logger.log_images(epoch, "Images/View 2", view2)
                    with torch.no_grad():
                        logger.log_images(epoch, "Model Ouput/Online Model", view_attention(self.online_model_encoder.encoder(view1.unsqueeze(0))[0]))
                        logger.log_images(epoch, "Model Output/Target Model", view_attention(self.target_model.encoder(view1.unsqueeze(0))[0]))

            if epoch % 10 == 0:
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir, exist_ok=True)

                if validation_loss < best_validation_loss:
                    save_model(self.online_model_encoder, f"SegFormer_{self.online_model_encoder.encoder_variant}_{train_loader.batch_size}_{epoch}", save_dir)
                    best_validation_loss = validation_loss

            pbar.update()
        profiler.stop()

        if logger is not None:
            logger.log_hparams(hparams={"batch_size": train_loader.batch_size, "learning_rate": lr_rate}, 
                               metrics={"Training Loss": training_loss, "Validation Loss": validation_loss}
                               )

        return self.online_model_encoder

Could you remove the usage of tqdm as well as the profiling to check if the memory increase could be related to these packages?
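
Something along these lines would do as a quick check. This is only a minimal sketch, not your original method: it reuses the `online_model`, `target_model`, `calculate_loss`, and `device` names from your snippet, iterates directly over the DataLoader instead of a tqdm wrapper, and leaves out mixed precision and the profiler for brevity.

# Minimal sketch of the epoch loop with tqdm and the GPU profiler removed, so any
# remaining memory growth cannot be coming from those packages.
def train_one_epoch_no_tqdm(self, train_loader, optimizer, **kwargs):
    running_training_loss = 0.0
    self.online_model.train()

    for view1, view2 in train_loader:                       # iterate directly, no tqdm wrapper
        view1, view2 = view1.to(device), view2.to(device)
        optimizer.zero_grad(set_to_none=True)

        o1 = self.online_model(view1)
        o2 = self.online_model(view2)
        with torch.no_grad():                                # target network gets no gradients
            t1 = self.target_model(view1)
            t2 = self.target_model(view2)

        loss = (self.calculate_loss(o1, t2.detach()) + self.calculate_loss(o2, t1.detach())).mean()
        loss.backward()
        optimizer.step()

        running_training_loss += loss.item() / len(train_loader)

    return running_training_loss, (view1[0], view2[0])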

Thank you for the suggestion @ptrblck. I identified the issue in my code: it stemmed from the large JSON trace file generated by the TensorBoard profiler. Specifically, with `active` set to 1 in the profiler schedule, the trace file was approximately 6.5 GB.

prof = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),
    with_stack=True,
    with_flops=True,
    profile_memory=True)

Initially, I had set `active` to 3, as shown in the PyTorch documentation example, which resulted in a trace file of around 20 GB. Writing and reading such a large file consumed substantial system memory and swap, which became the bottleneck and ultimately led to the process being killed.
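
For anyone who hits the same problem: a lighter profiler configuration along these lines keeps the trace file much smaller. This is a sketch of the direction rather than my final settings; the `with_stack`, `with_flops`, and `profile_memory` options in particular add a lot to the trace size.

# Sketch of a leaner profiler setup: keep the active window at a single step and
# disable the most expensive recording options so the TensorBoard trace stays small.
prof = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),
    with_stack=False,        # stack recording inflates the trace considerably
    with_flops=False,
    profile_memory=False)    # enable only when memory profiling is actually needed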

Good to hear you’ve isolated the issue!