Calculate the average model for k-fold cross-validation models

Hi, I am trying to calculate the average model from the five models generated by k-fold cross-validation (five folds).

I tried the code below, but it doesn't work.

Also, if I run each fold's model separately, only the last model works; in our case that is the fifth model (with 3 folds it would be the third model).

k_folds =5
num_epochs = 5
 
# For fold results
results = {}
 
# Set fixed random number seed
#torch.manual_seed(0)
  
dataset = torch.utils.data.ConcatDataset([image_datasets['train'],])

print(f'<>------------ Dataset length : {len(dataset)} ------------<>')
# Define the K-fold Cross Validator

kfold = StratifiedKFold(n_splits=k_folds, shuffle=True)

# kfold = KFold(n_splits=k_folds, shuffle=True)
since = time.time()

# K-fold Cross Validation model evaluation
for fold, (train_ids, validate_ids) in enumerate(kfold.split(dataset,image_datasets['train'].targets)):
    
    # Print
    print(f'<>-----------------FOLD {fold+1}-----------------<>')
    print('train: %s, test: %s' % (len(train_ids), len(validate_ids)))
    
    # Sample elements randomly from a given list of ids, no replacement.
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    validate_subsampler = torch.utils.data.SubsetRandomSampler(validate_ids)
    
    
    # Define data loaders for training and validation data in this fold
    
    trainloader = torch.utils.data.DataLoader(
                      dataset, 
                      batch_size=10,  num_workers= 0,sampler=train_subsampler)
    validatloader = torch.utils.data.DataLoader(
                      dataset,
                      batch_size=10, num_workers= 0, sampler=validate_subsampler)
    
    # Init the neural network
    model = MyEnsemble( modelA, modelB, modelC , modelD).to(device)
    #reset layer's weights from previous training
    model.apply(reset_weights)
    
    # Initialize optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    
    # Run the training loop for defined number of epochs
    Training_loss=[]
    Validation_loss=[]
  
    valid_loss_min = np.inf  # track change in validation loss
    
    for epoch in range(1, num_epochs+1):
        print('Epoch: {}  '.format(epoch))
        #print('--------')
        
        # keep track of training and validation loss
        train_loss = 0.0
        valid_loss = 0.0
        # train
        model.train()
        for data, target in trainloader:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # update training loss
            train_loss += loss.item() * data.size(0)
        # validate
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data, target in validatloader:
                if torch.cuda.is_available():
                    data, target = data.cuda(), target.cuda()
                # forward pass: compute predicted outputs by passing inputs to the model
                output = model(data)
                # calculate the batch loss
                loss = criterion(output, target)
                # update average validation loss
                valid_loss += loss.item() * data.size(0)
                # update total and correct prediction counts
                _, predicted = torch.max(output, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        # accuracy of the latest epoch for this fold
        results[fold] = 100.0 * (correct / total)

        # calculate average losses
        train_loss = train_loss / len(trainloader.dataset)
        valid_loss = valid_loss / len(validatloader.dataset)

        Training_loss.append(train_loss / len(trainloader))
        Validation_loss.append(valid_loss / len(validatloader))
        
        # print training/validation statistics 
        print('Training Loss: {:.6f} \tValidation Loss: {:.6f} \t '.format(
         train_loss, valid_loss))
        
        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print("==============================================================================================")
            print('Validation loss decreased ({:.6f} --> {:.6f}).  >>>>>>>  Saving model ...'.format(
                valid_loss_min, valid_loss))
            print("==============================================================================================")
            save_path = f'./EnsembleModelfold-{fold+1}.pt'
            # torch.save(model.state_dict(), save_path)
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, save_path)
            valid_loss_min = valid_loss
    print('Accuracy for fold %d: %d %%' % (fold+1, 100.0 * correct / total))
#     plt.figure(figsize = (25,10))                
#     plt.plot(Training_loss, label='Training loss')
#     plt.plot(Validation_loss, label='Validation loss')
#     plt.legend(frameon=False)
#     results[fold] = 100.0 * (correct / total)
    
# Print fold results
print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
print('--------------------------------')
total_accuracy = 0.0
for key, value in results.items():
    print(f'Fold {key}: {value} %')
    total_accuracy += value
print(f'Average: {total_accuracy/len(results)} %')

time_elapsed = time.time() - since
print('Training completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

AverageModel = MyEnsemble( modelA, modelB, modelC , modelD).to(device)
for key in AverageModel.state_dict():
    AverageModel.state_dict()[key]=AverageModel.state_dict()[key].zero_()

checkpoint_names=['EnsembleModelfold-1.pt','EnsembleModelfold-2.pt','EnsembleModelfold-3.pt','EnsembleModelfold-4.pt','EnsembleModelfold-5.pt']
for name in checkpoint_names:
    checkpoint = torch.load(name)
    fold_model = MyEnsemble( modelA, modelB, modelC , modelD).to(device)
    fold_model.load_state_dict(checkpoint['model_state_dict'])
    for key in AverageModel.state_dict():
        AverageModel.state_dict()[key]+=fold_model.state_dict()[key]
for key in AverageModel.state_dict():
    AverageModel.state_dict()[key]=AverageModel.state_dict()[key]/len(checkpoint_names)
        
AverageModel.state_dict()

What is the actual error that you are getting here?

Thanks for replying,…

When I run the model on the test dataset, one class gives me 100% accuracy and the other class gives 0% accuracy.
This happens for the first four models; only the last model gives me an acceptable result.

Do you notice anything strange about the output of the models? Are they stuck at a single prediction value?
What do the training curves look like for the models?
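
If it helps, here is a minimal sketch of that check (it assumes the saved fold checkpoints, MyEnsemble, device, and the validatloader from your code; it simply counts how often each class index is predicted):

import torch
from collections import Counter

# Load one saved fold checkpoint (names assumed from the training code above)
checkpoint = torch.load('EnsembleModelfold-1.pt', map_location=device)
fold_model = MyEnsemble(modelA, modelB, modelC, modelD).to(device)
fold_model.load_state_dict(checkpoint['model_state_dict'])
fold_model.eval()

# Count how many times each class index is predicted
pred_counts = Counter()
with torch.no_grad():
    for data, target in validatloader:
        output = fold_model(data.to(device))
        _, predicted = torch.max(output, 1)
        pred_counts.update(predicted.cpu().tolist())

# If nearly all predictions fall into one class, the model has collapsed to a single output
print(pred_counts)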

Here is the output of the training

<>------------ Dataset length : 8583 ------------<>
<>-----------------FOLD 1-----------------<>
train: 6866, test: 1717
Epoch: 1  
Training Loss: 0.486925 	Validation Loss: 0.112415 	 
==============================================================================================
Validation loss decreased (inf --> 0.112415).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 2  
Training Loss: 0.443527 	Validation Loss: 0.106931 	 
==============================================================================================
Validation loss decreased (0.112415 --> 0.106931).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 3  
Training Loss: 0.426966 	Validation Loss: 0.105515 	 
==============================================================================================
Validation loss decreased (0.106931 --> 0.105515).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 4  
Training Loss: 0.415489 	Validation Loss: 0.095385 	 
==============================================================================================
Validation loss decreased (0.105515 --> 0.095385).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 5  
Training Loss: 0.409037 	Validation Loss: 0.097692 	 
Accuracy for fold 1: 76 %
<>-----------------FOLD 2-----------------<>
train: 6866, test: 1717
Epoch: 1  
Training Loss: 0.485727 	Validation Loss: 0.120089 	 
==============================================================================================
Validation loss decreased (inf --> 0.120089).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 2  
Training Loss: 0.445043 	Validation Loss: 0.104871 	 
==============================================================================================
Validation loss decreased (0.120089 --> 0.104871).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 3  
Training Loss: 0.428421 	Validation Loss: 0.120718 	 
Epoch: 4  
Training Loss: 0.416499 	Validation Loss: 0.098932 	 
==============================================================================================
Validation loss decreased (0.104871 --> 0.098932).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 5  
Training Loss: 0.410152 	Validation Loss: 0.098041 	 
==============================================================================================
Validation loss decreased (0.098932 --> 0.098041).  >>>>>>>  Saving model ...
==============================================================================================
Accuracy for fold 2: 77 %
<>-----------------FOLD 3-----------------<>
train: 6866, test: 1717
Epoch: 1  
Training Loss: 0.485534 	Validation Loss: 0.110139 	 
==============================================================================================
Validation loss decreased (inf --> 0.110139).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 2  
Training Loss: 0.448009 	Validation Loss: 0.101339 	 
==============================================================================================
Validation loss decreased (0.110139 --> 0.101339).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 3  
Training Loss: 0.424282 	Validation Loss: 0.098663 	 
==============================================================================================
Validation loss decreased (0.101339 --> 0.098663).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 4  
Training Loss: 0.408519 	Validation Loss: 0.097198 	 
==============================================================================================
Validation loss decreased (0.098663 --> 0.097198).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 5  
Training Loss: 0.407020 	Validation Loss: 0.103193 	 
Accuracy for fold 3: 76 %
<>-----------------FOLD 4-----------------<>
train: 6867, test: 1716
Epoch: 1  
Training Loss: 0.484172 	Validation Loss: 0.113273 	 
==============================================================================================
Validation loss decreased (inf --> 0.113273).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 2  
Training Loss: 0.445774 	Validation Loss: 0.103834 	 
==============================================================================================
Validation loss decreased (0.113273 --> 0.103834).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 3  
Training Loss: 0.427883 	Validation Loss: 0.103864 	 
Epoch: 4  
Training Loss: 0.413481 	Validation Loss: 0.101294 	 
==============================================================================================
Validation loss decreased (0.103834 --> 0.101294).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 5  
Training Loss: 0.404326 	Validation Loss: 0.099928 	 
==============================================================================================
Validation loss decreased (0.101294 --> 0.099928).  >>>>>>>  Saving model ...
==============================================================================================
Accuracy for fold 4: 76 %
<>-----------------FOLD 5-----------------<>
train: 6867, test: 1716
Epoch: 1  
Training Loss: 0.472580 	Validation Loss: 0.105664 	 
==============================================================================================
Validation loss decreased (inf --> 0.105664).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 2  
Training Loss: 0.427865 	Validation Loss: 0.097842 	 
==============================================================================================
Validation loss decreased (0.105664 --> 0.097842).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 3  
Training Loss: 0.408531 	Validation Loss: 0.094644 	 
==============================================================================================
Validation loss decreased (0.097842 --> 0.094644).  >>>>>>>  Saving model ...
==============================================================================================
Epoch: 4  
Training Loss: 0.401408 	Validation Loss: 0.103782 	 
Epoch: 5  
Training Loss: 0.395811 	Validation Loss: 0.091990 	 
==============================================================================================
Validation loss decreased (0.094644 --> 0.091990).  >>>>>>>  Saving model ...
==============================================================================================
Accuracy for fold 5: 78 %
K-FOLD CROSS VALIDATION RESULTS FOR 5 FOLDS
--------------------------------
Fold 0: 76.93651718112989 %
Fold 1: 77.34420500873617 %
Fold 2: 76.41234711706466 %
Fold 3: 76.3986013986014 %
Fold 4: 78.55477855477857 %
Average: 77.12928985206214 %
Training completed in 130m 3s

Ok, so the actual training looks fine but why are you averaging all of the parameters across the models? Do you mean to average the predictions instead?
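
If averaging the predictions is what you are after, a minimal sketch could look like this (the fold checkpoints, MyEnsemble, device, and a test loader are assumed from your code; averaging the class probabilities is just one common choice):

import torch

# Load the five fold models saved during training (names assumed from the code above)
fold_models = []
for i in range(1, 6):
    checkpoint = torch.load(f'EnsembleModelfold-{i}.pt', map_location=device)
    m = MyEnsemble(modelA, modelB, modelC, modelD).to(device)
    m.load_state_dict(checkpoint['model_state_dict'])
    m.eval()
    fold_models.append(m)

correct, total = 0, 0
with torch.no_grad():
    for data, target in testloader:  # testloader is assumed to exist
        data, target = data.to(device), target.to(device)
        # Average the per-model class probabilities instead of the parameters
        probs = torch.stack([torch.softmax(m(data), dim=1) for m in fold_models]).mean(dim=0)
        _, predicted = torch.max(probs, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Prediction-averaging ensemble accuracy: {100.0 * correct / total:.2f} %')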

I want to create averageModel.pt so I can use it later.
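
For reference, this is roughly how I intend to save and reload it (just a sketch, assuming AverageModel from the code above actually holds the averaged parameters):

# Save the averaged weights for later use (path name is just an example)
torch.save(AverageModel.state_dict(), './averageModel.pt')

# Later: rebuild the ensemble and load the averaged weights
loaded_model = MyEnsemble(modelA, modelB, modelC, modelD).to(device)
loaded_model.load_state_dict(torch.load('./averageModel.pt', map_location=device))
loaded_model.eval()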

I solved the problem, but there is something I can't understand. With this block of code the problem is solved and the generated average model works as expected:

model_1=torch.load('Ensemble_Modelfold-1.pt')
model_2=torch.load('Ensemble_Modelfold-2.pt')
model_3=torch.load('Ensemble_Modelfold-3.pt')
model_4=torch.load('Ensemble_Modelfold-4.pt')
model_5=torch.load('Ensemble_Modelfold-5.pt')

for key in model_1.state_dict():
    model_1.state_dict()[key] = (model_1.state_dict()[key] + model_2.state_dict()[key]+  model_3.state_dict()[key]+model_4.state_dict()[key] + model_5.state_dict()[key] ) / 5
    
Average_Model = MyEnsemble( modelA, modelB, modelC , modelD).to(device)
Average_Model.load_state_dict(model_1.state_dict())

But if I use this block, which is almost identical to the first one, the generated average model gives me the same issue:

model_1=torch.load('Ensemble_Modelfold-1.pt')
model_2=torch.load('Ensemble_Modelfold-2.pt')
model_3=torch.load('Ensemble_Modelfold-3.pt')
model_4=torch.load('Ensemble_Modelfold-4.pt')
model_5=torch.load('Ensemble_Modelfold-5.pt')

sd1=model_1.state_dict()
sd2=model_2.state_dict()
sd3=model_3.state_dict()
sd4=model_4.state_dict()
sd5=model_5.state_dict()

for key in sd1:
    sd1[key] = (sd1[key] + sd2[key] + sd3[key]+sd4[key] + sd5[key] ) / 5
    
Average_Model = MyEnsemble( modelA, modelB, modelC , modelD).to(device)
Average_Model.load_state_dict(sd1)

So, why is this happening?