Dropout (?) causes different model convergence for training+validation vs. training only

We are facing a very strange issue. We tested the exact same model in two different execution settings. In the first setting, for a given number of epochs, we train on mini-batches for one epoch and then evaluate on the validation set, before moving on to the next epoch. Naturally, we call model.train() before each training epoch and model.eval() before validation.

Then we take the exact same model (same initialization, same dataset, same number of epochs, etc.) and simply train it, without validating after each epoch.

Looking only at performance on the training set, we observed that, even with all seeds fixed, the two training procedures evolve differently and produce quite different metrics (loss, accuracy, and so on). Specifically, the training-only procedure performs worse.

We also observe the following things:

  • It is not a reproducibility issue: multiple executions of the same procedure produce exactly the same results, as intended (the seeding we use is sketched right after this list);
  • Removing dropout, the problem vanishes;
  • The BatchNorm1d layer, which also behaves differently between training and evaluation, seems to work properly;
  • The issue persists whether we train on TPUs or on CPUs.
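
For context, the seeding we refer to is along these lines (an illustrative sketch, not our exact code):

import random
import numpy as np
import torch

def set_all_seeds(seed=42):
    # Seed every RNG a PyTorch run touches, so that repeated runs
    # of the same procedure are bit-identical.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)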

We have tried PyTorch 1.6, the PyTorch nightly build, and XLA 1.6.

We have lost a full day trying to tackle this issue (and no, we cannot avoid using dropout). Does anyone have any idea how to solve it?

Thank you very much!

P.S. Here is the code employed for training (on CPU).

import time

import torch
import torch.nn as nn
import torch.optim as optim


def sigmoid(x):
    # Same as torch.sigmoid(x), which would be the more numerically stable choice
    return 1 / (1 + torch.exp(-x))


def _run(model, EPOCHS, training_data_in, validation_data_in=None):
    
    def train_fn(train_dataloader, model, optimizer, criterion):

        running_loss = 0.
        running_accuracy = 0.
        running_tp = 0.
        running_tn = 0.
        running_fp = 0.
        running_fn = 0.
        
        model.train()

        for batch_idx, (ecg, spo2, labels) in enumerate(train_dataloader, 1):

            optimizer.zero_grad() 
                
            outputs = model(ecg)

            loss = criterion(outputs, labels)
                        
            loss.backward() # calculate the gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step() # update the network weights
                                                
            running_loss += loss.item()
            predicted = torch.round(sigmoid(outputs.detach()))  # apply the sigmoid here; it is not part of the model (BCEWithLogitsLoss takes raw logits)
            
            running_accuracy += (predicted == labels).sum().item() / labels.size(0)   
            
            # Confusion-matrix counts via arithmetic on {0, 1} predictions/labels
            fp = ((predicted - labels) == 1.).sum().item()   # predicted 1, label 0
            fn = ((predicted - labels) == -1.).sum().item()  # predicted 0, label 1
            tp = ((predicted + labels) == 2.).sum().item()   # predicted 1, label 1
            tn = ((predicted + labels) == 0.).sum().item()   # predicted 0, label 0
            running_tp += tp
            running_fp += fp
            running_tn += tn
            running_fn += fn
            
        retval = {'loss':running_loss / batch_idx,
                  'accuracy':running_accuracy / batch_idx,
                  'tp':running_tp,
                  'tn':running_tn,
                  'fp':running_fp,
                  'fn':running_fn
                 }
            
        return retval
            

        
    def valid_fn(valid_dataloader, model, criterion):

        running_loss = 0.
        running_accuracy = 0.
        running_tp = 0.
        running_tn = 0.
        running_fp = 0.
        running_fn = 0.
 
        model.eval()
        
        for batch_idx, (ecg, spo2, labels) in enumerate(valid_dataloader, 1):

            outputs = model(ecg)

            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            predicted = torch.round(sigmoid(outputs.detach()))  # apply the sigmoid here, as in train_fn

            running_accuracy += (predicted == labels).sum().item() / labels.size(0)  
            
            # Same confusion-matrix arithmetic as in train_fn
            fp = ((predicted - labels) == 1.).sum().item()
            fn = ((predicted - labels) == -1.).sum().item()
            tp = ((predicted + labels) == 2.).sum().item()
            tn = ((predicted + labels) == 0.).sum().item()
            running_tp += tp
            running_fp += fp
            running_tn += tn
            running_fn += fn
            
        retval = {'loss':running_loss / batch_idx,
                  'accuracy':running_accuracy / batch_idx,
                  'tp':running_tp,
                  'tn':running_tn,
                  'fp':running_fp,
                  'fn':running_fn
                 }
            
        return retval
    
    
    
    # Defining data loaders

    train_dataloader = torch.utils.data.DataLoader(training_data_in, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)  # BATCH_SIZE is a global
    
    if validation_data_in is not None:
        validation_dataloader = torch.utils.data.DataLoader(validation_data_in, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)


    # Defining the loss function
    criterion = nn.BCEWithLogitsLoss()
    
    
    # Defining the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, amsgrad=False, eps=1e-07)


    # Training code
    
    metrics_history = {"loss": [], "accuracy": [], "precision": [], "recall": [],
                       "f1": [], "specificity": [], "accuracy_bis": [],
                       "tp": [], "tn": [], "fp": [], "fn": [],
                       "val_loss": [], "val_accuracy": [], "val_precision": [], "val_recall": [],
                       "val_f1": [], "val_specificity": [], "val_accuracy_bis": [],
                       "val_tp": [], "val_tn": [], "val_fp": [], "val_fn": []}
    
    train_begin = time.time()
    for epoch in range(EPOCHS):
        start = time.time()

        print("EPOCH:", epoch+1)

        train_metrics = train_fn(train_dataloader=train_dataloader, 
                                 model=model,
                                 optimizer=optimizer, 
                                 criterion=criterion)
        
        metrics_history["loss"].append(train_metrics["loss"])
        metrics_history["accuracy"].append(train_metrics["accuracy"])
        metrics_history["tp"].append(train_metrics["tp"])
        metrics_history["tn"].append(train_metrics["tn"])
        metrics_history["fp"].append(train_metrics["fp"])
        metrics_history["fn"].append(train_metrics["fn"])
        
        # Derived metrics; the guards avoid division by zero
        precision = train_metrics["tp"] / (train_metrics["tp"] + train_metrics["fp"]) if train_metrics["tp"] > 0 else 0
        recall = train_metrics["tp"] / (train_metrics["tp"] + train_metrics["fn"]) if train_metrics["tp"] > 0 else 0
        specificity = train_metrics["tn"] / (train_metrics["tn"] + train_metrics["fp"]) if train_metrics["tn"] > 0 else 0
        f1 = 2*precision*recall / (precision + recall) if precision*recall > 0 else 0
        metrics_history["precision"].append(precision)
        metrics_history["recall"].append(recall)
        metrics_history["f1"].append(f1)
        metrics_history["specificity"].append(specificity)
        
        
        
        if validation_data_in is not None:    
            # Compute the metrics on the validation data, in the same way as for training
            with torch.no_grad():  # disable gradient tracking during validation

                val_metrics = valid_fn(valid_dataloader=validation_dataloader, 
                                       model=model,
                                       criterion=criterion)

                metrics_history["val_loss"].append(val_metrics["loss"])
                metrics_history["val_accuracy"].append(val_metrics["accuracy"])
                metrics_history["val_tp"].append(val_metrics["tp"])
                metrics_history["val_tn"].append(val_metrics["tn"])
                metrics_history["val_fp"].append(val_metrics["fp"])
                metrics_history["val_fn"].append(val_metrics["fn"])

                val_precision = val_metrics["tp"] / (val_metrics["tp"] + val_metrics["fp"]) if val_metrics["tp"] > 0 else 0
                val_recall = val_metrics["tp"] / (val_metrics["tp"] + val_metrics["fn"]) if val_metrics["tp"] > 0 else 0
                val_specificity = val_metrics["tn"] / (val_metrics["tn"] + val_metrics["fp"]) if val_metrics["tn"] > 0 else 0
                val_f1 = 2*val_precision*val_recall / (val_precision + val_recall) if val_precision*val_recall > 0 else 0
                metrics_history["val_precision"].append(val_precision)
                metrics_history["val_recall"].append(val_recall)
                metrics_history["val_f1"].append(val_f1)
                metrics_history["val_specificity"].append(val_specificity)


            print("  > Training/validation loss:", round(train_metrics['loss'], 4), round(val_metrics['loss'], 4))
            print("  > Training/validation accuracy:", round(train_metrics['accuracy'], 4), round(val_metrics['accuracy'], 4))
            print("  > Training/validation precision:", round(precision, 4), round(val_precision, 4))
            print("  > Training/validation recall:", round(recall, 4), round(val_recall, 4))
            print("  > Training/validation f1:", round(f1, 4), round(val_f1, 4))
            print("  > Training/validation specificity:", round(specificity, 4), round(val_specificity, 4))
        else:
            print("  > Training loss:", round(train_metrics['loss'], 4))
            print("  > Training accuracy:", round(train_metrics['accuracy'], 4))
            print("  > Training precision:", round(precision, 4))
            print("  > Training recall:", round(recall, 4))
            print("  > Training f1:", round(f1, 4))
            print("  > Training specificity:", round(specificity, 4))


        print("Completed in:", round(time.time() - start, 1), "seconds \n")

    print("Training completed in:", round((time.time()- train_begin)/60, 1), "minutes")    

    
    
    # Save the model weights
    torch.save(model.state_dict(), './nnet_model.pt')
    
    
    # Save the metrics history
    torch.save(metrics_history, 'training_history')
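
For completeness, the two execution settings correspond to calling _run with and without validation data (illustrative names; BATCH_SIZE, model and the datasets are defined elsewhere):

# Setting 1: training + validation after each epoch
_run(model, EPOCHS, training_data_in=train_dataset, validation_data_in=val_dataset)

# Setting 2: training only
_run(model, EPOCHS, training_data_in=train_dataset)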

Did you try using different seeds without the validation loop and checking how the training behaves?
My guess is that your training is not very stable, so the calls into the pseudo-random number generator during the validation loop shift the subsequent random values (e.g. the dropout masks and the shuffling order), and the two runs therefore diverge during training.
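
You can see the mechanism in isolation with a minimal sketch: consuming a single random number between two dropout calls changes the next dropout mask, even though the seed is fixed:

import torch

x = torch.ones(8)
drop = torch.nn.Dropout(p=0.5)
drop.train()                  # dropout active, as during training

torch.manual_seed(0)
a = drop(x)                   # first dropout mask after seeding

torch.manual_seed(0)
_ = torch.rand(1)             # simulates e.g. a validation loop consuming the RNG
b = drop(x)                   # drawn from a shifted RNG stream

print(torch.equal(a, b))      # almost surely False: the extra draw changed the mask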

Here is an explanation of the RNG behavior during the validation loop when a DataLoader is used, together with a workaround. Also, here is a similar issue to yours.
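
The workaround boils down to snapshotting the global RNG state before the validation loop and restoring it afterwards, so validation cannot perturb the training randomness. A minimal sketch against the code above, assuming single-process CPU training (device RNGs, e.g. CUDA's, would need the same treatment):

# Inside the epoch loop, around the validation step:
rng_state = torch.get_rng_state()      # snapshot the CPU RNG state

with torch.no_grad():
    val_metrics = valid_fn(valid_dataloader=validation_dataloader,
                           model=model,
                           criterion=criterion)

torch.set_rng_state(rng_state)         # restore it, as if validation never ran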
