Doubts regarding stratified k-fold cross validation

I am trying to perform stratified k-fold cross-validation on a multi-class image classification problem (4 classes), but I have some doubts about it.

As I understand it, we train the model on each fold for a certain number of epochs, compute the performance on each fold's validation set, and average those results into a single metric (accuracy, or whichever metric is chosen).
My doubts are:
1) Do we reset the model weights, learning rate, and optimiser state for every fold?
2) If yes (which I highly doubt), how is this different from a normal hold-out method? It would just be training the model on differently split data, and since the weights are not saved anywhere, every fold would start from scratch.
3) If no, how do we reuse the same weights, learning rate, and optimiser state in each fold?
My understanding is that the learning rate and optimiser state should be reset for every fold, while the weights should be carried over from the previous folds.
Below is my code. It uses the default first-time weight initialisation of the pre-trained model, but the learning rate is carried over from the previous fold, which I cannot understand; I assume the optimiser state is carried over in the same way.

from torch.utils.data import DataLoader

batch_size = 512
df_train, df_test, splits = cross_validation_train_test(csv_file='file.csv', stratify_colname='labels') # noqa
for fold in range(5):
    print("Fold: ", fold)
    partition, labels = kfold(df_train, df_test, splits, fold, stratify_columns='labels') # noqa
    training_set = Dataset(partition['train_set'], labels, root_dir=root_dir, train_transform=True) # noqa
    validation_set = Dataset(partition['val_set'], labels, root_dir=root_dir, valid_transform=True) # noqa
    test_set = Dataset(partition['test_set'], labels, root_dir=root_dir, test_transform=None)
    train_loader = DataLoader(training_set, shuffle=True, pin_memory=True, num_workers=0, batch_size=batch_size) # noqa
    val_loader = DataLoader(validation_set, shuffle=True, pin_memory=True, num_workers=0, batch_size=batch_size)
    test_loader = DataLoader(test_set, shuffle=True, pin_memory=True, num_workers=0, batch_size=batch_size) # noqa
    data_transfer = {'train': train_loader,
                     'valid': val_loader,
                     'test': test_loader}
    train_model(model=model_transfer, loader=data_transfer, optimizer=optimizer, criterion=criterion_transfer, scheduler=scheduler, n_epochs=50) # noqa
  1. Yes, you reset the model parameters and optimizer state and train a new model in each iteration (fold).

  2. k-fold CV gives you a better “unbiased” estimate of the generalization performance of your model.

@rasbt explains this technique here and also compares it to other hold-out methods.

Often you would use k-fold CV if your dataset is small and your overall training might thus finish quickly.
I haven’t seen this technique used for a DL model on e.g. ImageNet and don’t know if you would benefit much, since this dataset is already large.
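To make the "new model per fold" point concrete, here is a minimal sketch using scikit-learn's `StratifiedKFold` on a balanced toy 4-class label array (the labels, sample count, and variable names are illustrative, not from the code above):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy labels for a 4-class problem (mirrors the 4-class setup in the question):
# 100 samples, 25 per class.
labels = np.array([0, 1, 2, 3] * 25)
X = np.arange(len(labels))  # stand-in for image indices

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_class_counts = []
for fold, (train_idx, val_idx) in enumerate(skf.split(X, labels)):
    # A fresh model and a fresh optimizer would be created here for each fold,
    # trained on train_idx, and evaluated on val_idx.
    counts = np.bincount(labels[val_idx], minlength=4)
    fold_class_counts.append(counts.tolist())

# Stratification keeps the class ratio identical in every validation fold.
print(fold_class_counts[0])  # → [5, 5, 5, 5]
```

Each validation fold ends up with the same class distribution as the full dataset, which is the whole point of stratifying.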


Shouldn't the weights of the model be re-initialized after each fold? Also, doesn't it require creating a new instance of the optimizer for each fold, since the optimizer uses model.parameters()?

Yes, I think the parameters should be reset in each fold (which I called an iteration in my previous post).
If you reset the parameters in place, you might not need to create a new optimizer.
However, you can also avoid potential issues by simply creating a new one.
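A minimal sketch of that per-fold reset, using a small stand-in model (the snapshot/reload idiom works the same way for a pretrained model whose initial weights you want to restore each fold):

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
# Snapshot the initial (e.g. pretrained) weights once, before any fold runs.
initial_state = copy.deepcopy(model.state_dict())
criterion = nn.CrossEntropyLoss()

for fold in range(3):
    # Reset the weights in place by reloading the snapshot ...
    model.load_state_dict(initial_state)
    # ... and create a brand-new optimizer, so momentum buffers from the
    # previous fold cannot leak into this one.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Dummy training step standing in for the real per-fold training loop.
    x, y = torch.randn(32, 8), torch.randint(0, 4, (32,))
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Because the optimizer is recreated after the reload, it also picks up the freshly restored parameters, so there is no stale reference or carried-over state between folds.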


Thank you for the answer. There wasn't an explicit function call like model_transfer.reset_parameters() to reset the model's parameters in the first example, so how is it done?

.reset_parameters() is a method defined on most parametric nn.Module subclasses (e.g. nn.Linear, nn.Conv2d), and you could call it on all registered submodules.
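For example, a generic sketch using model.apply(), guarding with hasattr since containers like nn.Sequential don't define reset_parameters():

```python
import torch
import torch.nn as nn

def weight_reset(m: nn.Module) -> None:
    # Only parametric layers (nn.Linear, nn.Conv2d, nn.BatchNorm2d, ...)
    # define reset_parameters(); containers like nn.Sequential do not.
    if hasattr(m, "reset_parameters"):
        m.reset_parameters()

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
before = model[0].weight.clone()

# .apply() visits the module and every submodule recursively,
# so every layer gets freshly re-initialized weights.
model.apply(weight_reset)
```

Note that this re-draws weights from the layers' default init schemes, so for a pretrained model you would instead reload a saved state_dict of the pretrained weights.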