What is the best way to apply k-fold cross-validation to a CNN?

This post seems to give an additional example of how to use cross validation in skorch.

Alternatively, you could directly use the sklearn.model_selection methods to create the indices for the splits and recreate the datasets via Subsets in each iteration.
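A minimal sketch of that approach, assuming scikit-learn's `KFold` and a toy `TensorDataset` in place of your real image dataset. The dataset shape, batch size and `random_state` are arbitrary illustrative choices:

```python
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset, TensorDataset

# Toy dataset standing in for your real one: 100 single-channel 28x28 images, 10 classes.
dataset = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

kfold = KFold(n_splits=5, shuffle=True, random_state=0)

# KFold only needs the number of samples, so split an index array instead of the data itself.
for fold, (train_idx, val_idx) in enumerate(kfold.split(np.arange(len(dataset)))):
    # Subset wraps the original dataset and only exposes the given indices (no data is copied).
    train_set = Subset(dataset, train_idx.tolist())
    val_set = Subset(dataset, val_idx.tolist())
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=16)
    print(f"fold {fold}: {len(train_set)} train / {len(val_set)} val samples")
    # ... create a fresh model and optimizer here, train on train_loader,
    # then evaluate on val_loader ...
```

Because each fold only builds new `Subset` views, the underlying data is loaded once and reused across all folds.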

Thank you very much!!!

Following up on this topic, this link could help anyone interested.

I think a simple loop can do the job; there is no need for any other library.

Here is some pseudocode for illustration:

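```python
from statistics import mean

results = []
for fold in range(total_fold):
    # split_dataset, MyModel and your_training_function are placeholders
    train_set, test_set = split_dataset(your_dataset, fold)
    model = MyModel()  # a fresh model for every fold
    results.append(your_training_function(model, train_set, test_set))
print(mean(results))
```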

Since a new model is created inside the loop, there is no need to reset its parameters.
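For reference, here is a runnable version of that loop under stated assumptions: synthetic data, a tiny hypothetical CNN (`make_model`) standing in for `MyModel`, and a hypothetical `train_and_evaluate` standing in for `your_training_function`. The splitting uses scikit-learn's `KFold` with `Subset`, which is one possible way to do what `split_dataset` does, not the only one:

```python
from statistics import mean

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in data (assumption): 60 RGB 16x16 images with 3 classes.
dataset = TensorDataset(torch.randn(60, 3, 16, 16), torch.randint(0, 3, (60,)))
total_fold = 3


def make_model():
    """Hypothetical tiny CNN standing in for MyModel()."""
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 3),
    )


def train_and_evaluate(model, train_set, test_set, epochs=2):
    """Hypothetical stand-in for your_training_function; returns test accuracy."""
    # The optimizer is created per fold as well, because it holds this model's parameters.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for x, y in DataLoader(train_set, batch_size=8, shuffle=True):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in DataLoader(test_set, batch_size=32):
            correct += (model(x).argmax(dim=1) == y).sum().item()
    return correct / len(test_set)


kfold = KFold(n_splits=total_fold, shuffle=True, random_state=0)
results = []
for fold, (train_idx, test_idx) in enumerate(kfold.split(np.arange(len(dataset)))):
    train_set = Subset(dataset, train_idx.tolist())
    test_set = Subset(dataset, test_idx.tolist())
    model = make_model()  # fresh, randomly initialised model for every fold
    accuracy = train_and_evaluate(model, train_set, test_set)
    results.append(accuracy)
    print(f"fold {fold}: test accuracy {accuracy:.3f}")
print(f"mean accuracy over {total_fold} folds: {mean(results):.3f}")
```

Creating the optimizer inside the per-fold function matters for the same reason as creating the model inside the loop: an optimizer built once outside the loop would keep pointing at the first fold's parameters.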

This was useful for understanding it better. All it really does is loop over different splits of the dataset. If the dataset is split with random_split, we can just repeat the split and the training a few times. I suspect this is less rigorous than proper k-fold cross-validation, since repeated random splits do not guarantee disjoint test sets, but it definitely helps the understanding!

David
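To make that difference concrete, here is a small sketch (toy 20-sample dataset, arbitrary seeds) contrasting repeated `random_split` calls, where test sets drawn in different iterations may overlap and some samples may never be tested, with `KFold`, whose test folds are disjoint and cover every sample exactly once:

```python
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(20).float())  # 20 toy samples

# Repeated random splitting: a new random 16/4 split in every iteration.
random_test_sets = []
for repeat in range(5):
    generator = torch.Generator().manual_seed(repeat)
    train_set, test_set = random_split(dataset, [16, 4], generator=generator)
    random_test_sets.append(set(test_set.indices))  # random_split returns Subsets

# k-fold: 5 disjoint test folds that together cover every sample exactly once.
kfold = KFold(n_splits=5, shuffle=True, random_state=0)
kfold_test_sets = [set(test_idx.tolist()) for _, test_idx in kfold.split(np.arange(20))]

covered_random = set().union(*random_test_sets)
covered_kfold = set().union(*kfold_test_sets)
print("distinct samples tested with random_split:", len(covered_random))
print("distinct samples tested with KFold:", len(covered_kfold))
```

Repeating random splits is a valid technique in its own right (often called repeated random sub-sampling), but its results are not the same as k-fold cross-validation.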

I have written a function to implement cross-validation; it may help you.

I would just suggest using a Subset(dataset, indices) wrapper over a dataset loaded in a fixed order, and specifying the indices with some predefined criterion, like
“if you have 50 samples and use 5-fold validation, then for the first fold use the first 40 indices for training and the rest for testing”.
We can use the above criterion to train and test on the range of samples calculated for a specific fold, save the results, and use them later.
This method is useful when training takes a long time and might be interrupted.
Please correct me if I am wrong :slight_smile:
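The index-range idea and the "save results, use them later" part could look roughly like this. `fold_indices`, `run_cross_validation`, the JSON results file and the `train_fn(train_idx, test_idx)` callback are all hypothetical names chosen for this sketch; in PyTorch you would build `Subset(dataset, train_idx)` and `Subset(dataset, test_idx)` inside `train_fn`. Contiguous blocks assume the dataset is already shuffled (not sorted by class), otherwise the folds can be badly unbalanced:

```python
import json
import os
import tempfile


def fold_indices(n_samples, n_folds, fold):
    """Return (train_indices, test_indices) for one fold using contiguous index blocks.

    Fold 0 tests on the last block, so with 50 samples and 5 folds it trains on
    indices 0-39 and tests on 40-49, as in the example above.
    """
    # Block boundaries; if n_samples is not divisible by n_folds, block sizes differ by at most one.
    bounds = [i * n_samples // n_folds for i in range(n_folds + 1)]
    block = n_folds - 1 - fold
    test = list(range(bounds[block], bounds[block + 1]))
    train = list(range(0, bounds[block])) + list(range(bounds[block + 1], n_samples))
    return train, test


def run_cross_validation(n_samples, n_folds, train_fn, results_path):
    """Run every fold, saving results after each one so an interrupted run can resume."""
    results = {}
    if os.path.exists(results_path):
        with open(results_path) as f:
            results = json.load(f)  # JSON object keys are strings, hence str(fold) below
    for fold in range(n_folds):
        if str(fold) in results:
            continue  # finished before the interruption, skip it
        train_idx, test_idx = fold_indices(n_samples, n_folds, fold)
        results[str(fold)] = train_fn(train_idx, test_idx)
        with open(results_path, "w") as f:
            json.dump(results, f)  # persist after every fold
    return [results[str(fold)] for fold in range(n_folds)]


# Usage with a dummy train_fn that just reports the test-set size instead of training.
path = os.path.join(tempfile.mkdtemp(), "cv_results.json")
scores = run_cross_validation(50, 5, lambda train_idx, test_idx: len(test_idx), path)
print(scores)  # [10, 10, 10, 10, 10]
```

If the run is interrupted, calling `run_cross_validation` again with the same `results_path` only trains the folds that have no saved result yet.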

Hello ptrblck, do you have any sample code that splits the data into train, test and val sets and also uses StratifiedKFold?

You could use the code example from here, which shows how to use sklearn.model_selection.StratifiedKFold.
Inside the “index loop”, you could then create torch.utils.data.Subset datasets from the drawn indices.
Let me know, if you need more information. :slight_smile:
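A sketch of a train/val/test split with stratification, assuming scikit-learn's `train_test_split` and `StratifiedKFold` and a toy imbalanced dataset: a stratified test set is held out once and never touched during cross-validation, and `StratifiedKFold` then splits the remaining samples into train and validation folds. The 20% test size and 4 folds are arbitrary choices:

```python
import numpy as np
import torch
from sklearn.model_selection import StratifiedKFold, train_test_split
from torch.utils.data import Subset, TensorDataset

# Toy imbalanced dataset (assumption): 100 samples, 80 of class 0 and 20 of class 1.
features = torch.randn(100, 1, 8, 8)
labels = torch.cat([torch.zeros(80, dtype=torch.long), torch.ones(20, dtype=torch.long)])
dataset = TensorDataset(features, labels)
targets = labels.numpy()
all_idx = np.arange(len(dataset))

# 1) Hold out a stratified test set once; it is only used for the final evaluation.
trainval_idx, test_idx = train_test_split(
    all_idx, test_size=0.2, stratify=targets, random_state=0
)
test_set = Subset(dataset, test_idx.tolist())

# 2) Stratified k-fold on the remaining samples gives the train/val splits.
skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
for fold, (train_pos, val_pos) in enumerate(skf.split(trainval_idx, targets[trainval_idx])):
    # skf.split returns positions within trainval_idx, so map them back to dataset indices.
    train_set = Subset(dataset, trainval_idx[train_pos].tolist())
    val_set = Subset(dataset, trainval_idx[val_pos].tolist())
    val_share = targets[trainval_idx[val_pos]].mean()
    print(f"fold {fold}: train={len(train_set)} val={len(val_set)} "
          f"class-1 share in val={val_share:.2f}")
```

Mapping the positions back through `trainval_idx` is the easy step to miss: indexing the full dataset with `train_pos` directly would silently mix in test samples.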

And how is the split_dataset function defined?
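`split_dataset` is not shown in the pseudocode, so the following is only one hypothetical definition, assuming scikit-learn's `KFold` and returning `torch.utils.data.Subset` objects. The fixed `seed` is what keeps the folds consistent across calls; without it, each call would shuffle differently and the test sets of different folds would no longer be disjoint:

```python
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import Subset, TensorDataset


def split_dataset(dataset, fold, total_fold=5, seed=0):
    """Hypothetical helper: return (train_set, test_set) Subsets for one fold.

    The fixed seed makes KFold produce the same shuffled partition on every call,
    so calling this once per fold yields disjoint test sets that cover the dataset.
    """
    kfold = KFold(n_splits=total_fold, shuffle=True, random_state=seed)
    splits = list(kfold.split(np.arange(len(dataset))))
    train_idx, test_idx = splits[fold]
    return Subset(dataset, train_idx.tolist()), Subset(dataset, test_idx.tolist())


# Usage with a toy dataset of 23 samples (fold sizes 5, 5, 5, 4, 4).
dataset = TensorDataset(torch.randn(23, 1, 8, 8), torch.randint(0, 2, (23,)))
for fold in range(5):
    train_set, test_set = split_dataset(dataset, fold)
    print(fold, len(train_set), len(test_set))
```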