Model is training infinitely

I have read similar issues posted on the forum, but the suggested approaches are not working, and I am greatly in need of your suggestion(s). Using Google Colab, I am training a UNet model that takes an input of shape [1, 1, 512, 512] (B, C, H, W).

```
def train_model(model, dataloaders, criterion, optimizer, num_epochs):
    since = time.time()

    container = {"train": {"loss": {"pred1":[], "pred2":[]}, 
                           "score":{"pred1":[], "pred2":[]}},
                 "val": {"loss": {"pred1":[], "pred2":[]},
                         "score":{"pred1":[], "pred2":[]}},
                 "learning_rate":[]}     

    best_model_wts = copy.deepcopy(model.state_dict())
    best_score = 0.0

    for epoch in range(num_epochs):
        
        start_time = time.time()
        
        epoch_loss = {"train":{"loss1":0, "loss2":0}, 
                      "val":{"loss1":0, "loss2":0}}
        epoch_score = {"train":{"score1":0, "score2":0},
                       "val":{"score1":0, "score2":0}}

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss1, running_loss2 = 0.0, 0.0
            running_score1, running_score2  = 0.0, 0.0

            # Iterate over data.
            for data in dataloaders[phase]:                
                inputs, label_pred1, label_pred2 = data
                                
                inputs = inputs.to(device) 
                label_pred1 = label_pred1.to(device)
                label_pred2 = label_pred2.to(device) 
                labels = torch.cat([label_pred1, label_pred2], dim=1) # not really needed
                
                # zero the parameter gradients
                optimizer.zero_grad()
                
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # use pretrained weight of others architecture                    
                    outputs = model(inputs)                    
                    loss1 = criterion(outputs[:, [0,1,2], :,:], label_pred1) 
                    loss2 = criterion(outputs[:, [3,4,5,6,7,8], :,:], label_pred2)
    
                    dice_coefficient_pred1 = dice_no_threshold(outputs[:, [0,1,2], :,:].detach().cpu(), label_pred1).item()
                    dice_coefficient_pred2 = dice_no_threshold(outputs[:, [3,4,5,6,7,8], :,:].detach().cpu(), label_pred2).item()
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':
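                        # note: retain_graph=True keeps the shared graph alive so that
                        # loss2.backward() can run afterwards; calling
                        # (loss1 + loss2).backward() once would give the same gradients
                        # with a single backward pass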
                        loss1.backward(retain_graph=True) 
                        loss2.backward()  
                        optimizer.step()
                
                # statistics
                running_loss1 += loss1.item() * inputs.size(0)
                running_loss2 += loss2.item() * inputs.size(0)
                running_score1 += dice_coefficient_pred1 * inputs.size(0)
                running_score2 += dice_coefficient_pred2 * inputs.size(0)
                
            # if phase == 'train':
            #     scheduler.step()
                
            epoch_loss[phase]["loss1"] = running_loss1 / len(dataloaders[phase].dataset)
            epoch_loss[phase]["loss2"] = running_loss2 / len(dataloaders[phase].dataset)
            epoch_score[phase]["score1"] = running_score1 / len(dataloaders[phase].dataset)
            epoch_score[phase]["score2"] = running_score2 / len(dataloaders[phase].dataset)

            # print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss[phase], epoch_acc[phase]))
            
            # deep copy the model !modify this
            # if phase == 'val' and epoch_score["val"]["score1"] > best_score:
            #     best_score = epoch_score["val"]
            #     best_model_wts = copy.deepcopy(model.state_dict())
                
            # storing experiment result for visualization
            if phase == 'val': 
                container["val"]["loss"]["pred1"].append(epoch_loss["val"]["loss1"])
                container["val"]["loss"]["pred2"].append(epoch_loss["val"]["loss2"])
                container["val"]["score"]["pred1"].append(epoch_score["val"]["score1"])
                container["val"]["score"]["pred2"].append(epoch_score["val"]["score2"])
            else:
                container["train"]["loss"]["pred1"].append(epoch_loss["train"]["loss1"])
                container["train"]["loss"]["pred2"].append(epoch_loss["train"]["loss2"])
                container["train"]["score"]["pred1"].append(epoch_score["train"]["score1"])
                container["train"]["score"]["pred2"].append(epoch_score["train"]["score2"])
        
        # container["learning_rate"].append([param_group['lr'] for param_group in optimizer.param_groups])              
        training_time = str(datetime.timedelta(seconds=time.time() - start_time))[:7]
        
        print("Epoch: {}/{}".format(epoch+1, num_epochs),
              "Training | loss1: {:.4f}".format(epoch_loss["train"]["loss1"]), "score1: {:.4f}".format(epoch_score["train"]["score1"]),
              "loss2: {:.4f}".format(epoch_loss["train"]["loss2"]), "score2: {:.4f}".format(epoch_score["train"]["score2"]),
              "Validation | loss1: {:.4f}".format(epoch_loss["val"]["loss1"]), "score1: {:.4f}".format(epoch_score["val"]["score1"]),
              "loss2: {:.4f}".format(epoch_loss["val"]["loss2"]), "score2: {:.4f}".format(epoch_score["val"]["score2"]), 
            #   "|Time: {}".format(training_time)
              )
        
    print()    
    print("Training & Validation Workflow Completed")
    print("="*40)
    time_elapsed = str(datetime.timedelta(seconds=time.time() - since))[:7]
    print('Total elapsed time {}'.format(time_elapsed))
    print('Best validation accuracy: {:.4f}'.format(best_score)) # best acc, loss, epoch

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, container

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    
    # hyper-parameters
    parser.add_argument("--epoch", type=int, default=3, help="epoch_number")
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
    parser.add_argument('--trainsize', type=int, default=512, help='set the size of training sample')
    parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
    parser.add_argument('--decay_epoch', type=int, default=50, help='every n epochs decay learning rate')
    parser.add_argument('--num_workers', type=int, default=0,help='number of workers in dataloader. In windows, set num_workers=0')
    
    # training dataset
    parser.add_argument('--train_path', type=str,
                        default='./Dataset/TrainingSet/LungInfection-Train/Doctor-label')
    parser.add_argument('--train_save', type=str, default=None,
                        help='If you use custom save path, please edit `--is_semi=True` and `--is_pseudo=True`')
    
    # model_lung_infection parameters
    parser.add_argument('--net_channel', type=int, default=32,
                        help='internal channel numbers in the Inf-Net, default=32, try larger for better accuracy')
    parser.add_argument('--n_classes', type=int, default=9,
                        help='binary segmentation when n_classes=1')
    parser.add_argument('--backbone', type=str, default='Res2Net50',
                        help='change different backbone, choice: VGGNet16, ResNet50, Res2Net50')
    
    opt = parser.parse_args()
    
    # ---- device setting ----
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # ---- pretrained architecture ----
    
    
    # ---- build models ----    
    model = UNet(input_channels=1, output_channels=opt.n_classes, outputs_activation="softmax")
    model = model.to(device)

    criterion = DiceLoss(activation="softmax2d")
    optimizer = torch.optim.Adam(model.parameters(), opt.lr, weight_decay = opt.decay_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=2, cooldown=2)
    current_lr = [param_group['lr'] for param_group in optimizer.param_groups][0]
    
    dataloaders = get_loader(opt.batchsize, opt.trainsize, opt.num_workers)

    model, container = train_model(model, dataloaders, criterion, optimizer, opt.epoch)
```

@ptrblck I would be glad for your support.

If I understand the issue correctly, your training doesn’t stop at all?
If so, could you add print statements and make sure that the training loop is indeed executed and the code isn’t hanging at one point?
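For example, a rough sketch of such debug prints, reusing the names from your posted loop (the exact placement and messages are up to you):

```
# sketch: print progress inside the batch loop to confirm it keeps advancing
for i, data in enumerate(dataloaders[phase]):
    print(f"{phase} iteration {i} started", flush=True)  # flush so Colab shows it immediately
    inputs, label_pred1, label_pred2 = data
    print(f"  batch loaded, input shape: {inputs.shape}", flush=True)
    ...  # rest of the loop unchanged
```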


It does not complete the first epoch after several hours of training, and at one point a CUDA out-of-memory error is displayed:
RuntimeError: CUDA out of memory. Tried to allocate 12.50 MiB (GPU 0; 10.92 GiB total capacity; 8.57 MiB already allocated; 9.28 GiB free; 4.68 MiB cached)

Could you check the allocated memory via print(torch.cuda.memory_allocated()) inside the training loop and see if the memory usage increases in each iteration?
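For example, something like this inside the batch loop (just a sketch; the print interval is arbitrary):

```
# sketch: log the allocated GPU memory every few iterations
for i, data in enumerate(dataloaders[phase]):
    if i % 5 == 0:
        print(i, torch.cuda.memory_allocated(), flush=True)
    ...  # rest of the loop unchanged
```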

I printed the allocated memory using a batch size of 4 at every 5th iteration. It constantly gives the same value, 445350912, at steps 0, 5, 10, ..., all the way up to step 100.

This would correspond to approx. only 424MB inside the loop. Did you see any increase in the allocated memory before the OOM error was raised?

No, the allocated memory did not increase. Could the configuration of the dataloader (num_workers=0, pin_memory=True) or the way I load my files be responsible for this? There are approximately 30,000 .npy files in total, and 7 files are needed, one at a time, by the dataloader.

This shouldn’t be the case, as the DataLoader loads the batches into CPU RAM as long as no to('cuda') or .cuda() operation is used inside the Dataset.
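In other words, the samples stay on the CPU until they are explicitly moved in the training loop. A minimal sketch of such a Dataset for .npy files (the class name, arguments, and paths are made up for illustration):

```
import numpy as np
import torch
from torch.utils.data import Dataset

class NpySegmentationDataset(Dataset):
    """Illustrative Dataset: everything stays on the CPU."""
    def __init__(self, image_paths, label1_paths, label2_paths):
        self.image_paths = image_paths
        self.label1_paths = label1_paths
        self.label2_paths = label2_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # np.load returns CPU arrays; no .to('cuda') / .cuda() here, so the
        # DataLoader (and its worker processes) only ever fill CPU RAM
        image = torch.from_numpy(np.load(self.image_paths[idx])).float()
        label1 = torch.from_numpy(np.load(self.label1_paths[idx])).float()
        label2 = torch.from_numpy(np.load(self.label2_paths[idx])).float()
        return image, label1, label2
```

The batches are then moved to the GPU only inside the training loop via inputs = inputs.to(device), as in the posted code.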


A colleague suggested not loading the data from Google Drive, which was my approach. I copied some of the files to Colab's local /content disk (because the total file size is larger than the allocated disk space), ran the code with a batch size of 8, and got results per epoch.
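In case it helps someone else, the copy step looked roughly like this (the source and destination paths below are just examples, not my exact ones):

```
import shutil

# assumes Google Drive is already mounted at /content/drive
src = "/content/drive/MyDrive/Dataset/TrainingSet"  # example Drive path
dst = "/content/TrainingSet"                        # Colab VM's local (fast) disk; must not exist yet
shutil.copytree(src, dst)
```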

I appreciate your support.

Quick one: is there any cloud service you would suggest that offers high RAM, good speed, and a larger disk that a student can use for a project?

Currently, my model is not learning; it gives the same result every epoch:

```
Epoch     4: reducing learning rate of group 0 to 2.0000e-02.
Epoch     9: reducing learning rate of group 0 to 4.0000e-03.
Epoch    14: reducing learning rate of group 0 to 8.0000e-04.
Epoch    19: reducing learning rate of group 0 to 1.6000e-04.
Epoch    24: reducing learning rate of group 0 to 3.2000e-05.
Epoch    29: reducing learning rate of group 0 to 6.4000e-06.
Epoch    34: reducing learning rate of group 0 to 1.2800e-06.
Epoch    39: reducing learning rate of group 0 to 2.5600e-07.
Epoch    44: reducing learning rate of group 0 to 5.1200e-08.
Epoch    49: reducing learning rate of group 0 to 1.0240e-08.
Epoch: 1/10 Training | loss1: 0.4706 score1: 0.8874 loss2: 0.8314 score2: 0.0593 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:19:14
Epoch: 2/10 Training | loss1: 0.4688 score1: 0.8931 loss2: 0.8315 score2: 0.0577 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:18:03
Epoch: 3/10 Training | loss1: 0.4688 score1: 0.8931 loss2: 0.8315 score2: 0.0577 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:17:53
Epoch: 4/10 Training | loss1: 0.4688 score1: 0.8931 loss2: 0.8315 score2: 0.0577 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:17:52
Epoch: 5/10 Training | loss1: 0.4688 score1: 0.8931 loss2: 0.8315 score2: 0.0577 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:17:51
Epoch: 6/10 Training | loss1: 0.4688 score1: 0.8931 loss2: 0.8315 score2: 0.0576 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:17:30
Epoch: 7/10 Training | loss1: 0.4688 score1: 0.8931 loss2: 0.8315 score2: 0.0576 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:17:23
Epoch: 8/10 Training | loss1: 0.4688 score1: 0.8931 loss2: 0.8315 score2: 0.0576 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:17:23
Epoch: 9/10 Training | loss1: 0.4688 score1: 0.8931 loss2: 0.8315 score2: 0.0576 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:17:28
Epoch: 10/10 Training | loss1: 0.4688 score1: 0.8932 loss2: 0.8315 score2: 0.0576 Validation | loss1: 0.4761 score1: 0.8717 loss2: 0.8314 score2: 0.0586 |Time: 0:17:36
```

Your scheduler is reducing the learning rate to the point where training stalls.
You could remove the scheduler and check whether your model is able to learn the training set. Once that works, you can then make sure it also generalizes well to the validation dataset.
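For example, something along these lines, reusing the posted train_model and simply leaving the scheduler out (the learning rate and epoch count are just placeholders):

```
# sketch: train without any LR scheduler and watch whether the training loss decreases
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # no scheduler attached
model, container = train_model(model, dataloaders, criterion, optimizer, num_epochs=20)

# optionally, try to overfit a handful of samples first; if the model cannot drive
# the training loss down on a tiny subset, the setup itself is suspect
```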
