K-fold cross validation on custom dataset

Hi Folks,

I am implementing k-fold cross validation for my PyTorch model, but I seem to have a problem with how I create the datasets, the transforms and the DataLoaders.
What I want to do: split all my image volumes (in .nii.gz) into k folds and run a separate training on each fold. The data augmentation transforms (train_transforms) should only be applied to the training data in each fold. Right now, the validation data gets augmented as well.
I used several tutorials for this, but they always download MNIST data, where you just set the “train” argument to True for training data and False for validation data. I assume my problem lies in copying these tutorials while using a custom dataset.

My Code:

First, I define my dataset and the transforms in an extra dataset.py script:

import os
from glob import glob

from monai.data import CacheDataset
from monai.transforms import (
    AddChanneld, Compose, LoadImaged, RandAffined,
    RandFlipd, RandRotate90d, ToTensord,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset

def prepare(data_dir):
    # Collect images and labels; sort both lists so image/label pairs line up
    images = sorted(glob(os.path.join(data_dir, 'images', '*.nii.gz')))
    labels = sorted(glob(os.path.join(data_dir, 'labels', '*.nii.gz')))

    # Split into {test_size} % validation set
    # (the remaining arguments of this call were cut off in the post)
    train_images, valid_images, train_labels, valid_labels = train_test_split(images, labels)

    # Combine images and labels into dictionaries
    train_files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(train_images, train_labels)]
    valid_files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(valid_images, valid_labels)]

    # set deterministic training for reproducibility

    # Define the data transforms and augmentations (random rotation and flips)
    train_transforms = Compose([
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),  # Add a channel as 1st element
        RandAffined(  # transform name is missing in the post; RandAffined is assumed from its arguments
            keys=['image', 'label'],
            mode=('bilinear', 'nearest'),
            rotate_range=(0.174533, 0.174533),  # +-10° in radians
            shear_range=(0.2, 0.5),
        ),
        RandFlipd(keys=['image', 'label'], spatial_axis=(0, 1)),
        RandRotate90d(keys=['image', 'label'], spatial_axes=(0, 1)),
        ToTensord(keys=['image', 'label']),
    ])

    valid_transforms = Compose([
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        ToTensord(keys=['image', 'label']),
    ])

    # Create the (fully cached) Dataset objects
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
    valid_ds = CacheDataset(data=valid_files, transform=valid_transforms, cache_rate=1.0)

    # Merging both datasets keeps each item's original transform
    dataset = ConcatDataset([train_ds, valid_ds])

    return dataset

Next, in my training script, I configure my k-fold cross validation and fetch my dataset:

import torch
from sklearn.model_selection import KFold

# ------------------------K-Fold cross validation----------------------------
# Configure k-fold cross validation
kfolds = 5
kfold = KFold(n_splits=kfolds, shuffle=True, random_state=42)
results = {} # For fold results

dataset = prepare(input)

Then, I start a loop over the dataset, using the kfold.split function. In this loop, random subsets are generated from the dataset and the dataloaders are initialized:

for fold, (train_ids, valid_ids) in enumerate(kfold.split(dataset)):
    MODULE_LOGGER.info(f"Running fold {fold+1}/{kfolds}")

    # Sample elements randomly from a given list of ids, no replacement.
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    valid_subsampler = torch.utils.data.SubsetRandomSampler(valid_ids)
    # Define data loaders for training and validation data in this fold
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=train_subsampler
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, sampler=valid_subsampler
    )
    # Avoid weight leakage by resetting weights for each fold
    # ---------------------------------------------------------------------------
    # ------------------------------Training Loop--------------------------------

    # Start typical PyTorch training loop....
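
The line that actually resets the weights is not shown in the loop above. As a rough sketch of one common way to do it, assuming a plain nn.Module whose layers implement reset_parameters() (the helper name reset_weights and the tiny placeholder model are made up for illustration, not taken from my real code):

```python
import torch
from torch import nn


def reset_weights(model: nn.Module) -> None:
    """Re-initialise every submodule that defines reset_parameters()."""
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()


torch.manual_seed(0)
# Placeholder model, only here to demonstrate the helper
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
before = model[0].weight.detach().clone()

reset_weights(model)  # would be called at the start of each fold
after = model[0].weight.detach().clone()
weights_changed = not torch.equal(before, after)
```

Creating a fresh model instance per fold should have the same effect. In either case the optimizer would presumably have to be re-created per fold as well, since it keeps its own state.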

I strongly suspect that my problem lies in the prepare function, where I concatenate the train and validation datasets into one after splitting them (which kind of makes the split using sklearn’s train_test_split obsolete). However, I struggle with how to loop over the k folds without using a concatenated dataset.
Another thing I noticed: I assume the dataloaders are initialized only once, at the beginning of my runtime (I can see progress bars while the data is loading). I saw a tutorial from MONAI where several dataloaders were initialized, depending on the number of folds.

I am confused about how to run k-fold CV on a custom dataset using a typical PyTorch training loop.
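
The suspicion about the concatenated dataset can be checked with a small toy sketch. It assumes that ConcatDataset simply forwards each index to the sub-dataset that owns it, so an item keeps the transform of the dataset it was originally put in. TaggedDataset is a hypothetical stand-in for CacheDataset that returns the name of its transform instead of applying one:

```python
from collections import Counter

from sklearn.model_selection import KFold
from torch.utils.data import ConcatDataset, Dataset


class TaggedDataset(Dataset):
    """Hypothetical stand-in for CacheDataset: returns the item and its transform's name."""

    def __init__(self, items, transform_name):
        self.items = items
        self.transform_name = transform_name

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index], self.transform_name


# 8 items that would be augmented, 2 items that would not
train_ds = TaggedDataset(list(range(8)), "train_transforms")
valid_ds = TaggedDataset(list(range(8, 10)), "valid_transforms")
dataset = ConcatDataset([train_ds, valid_ds])

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
valid_tag_counts = Counter()
for train_ids, valid_ids in kfold.split(list(range(len(dataset)))):
    # Record which transform each validation sample of this fold would go through
    for i in valid_ids:
        _, transform_name = dataset[int(i)]
        valid_tag_counts[transform_name] += 1

print(dict(valid_tag_counts))
```

Every item lands in a validation fold exactly once, so across the five folds 8 of the 10 validation samples go through train_transforms. That would explain why the validation data gets augmented.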


You need to create the datasets inside the k-fold for loop, so that you can pass train_transforms to the training split and valid_transforms to the validation split of each fold.
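
A minimal sketch of what that could look like, under a few assumptions: ToyDataset stands in for MONAI's CacheDataset, the two transform functions are placeholders for train_transforms and valid_transforms, and files stands in for the full list of image/label dicts. The point is only that the split is done on the raw file list and the datasets are built per fold:

```python
import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stand-in for CacheDataset: applies `transform` to each item on access."""

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.transform(self.data[index])


def train_transform(x):
    # Placeholder augmentation: adds random noise
    return torch.tensor([float(x)]) + torch.randn(1)


def valid_transform(x):
    # Deterministic, no augmentation
    return torch.tensor([float(x)])


files = list(range(10))  # placeholder for the list of image/label dicts
kfold = KFold(n_splits=5, shuffle=True, random_state=42)

fold_sizes = []
for fold, (train_ids, valid_ids) in enumerate(kfold.split(files)):
    # Split the raw file list, then build one dataset per split
    train_files = [files[i] for i in train_ids]
    valid_files = [files[i] for i in valid_ids]

    train_ds = ToyDataset(train_files, train_transform)
    valid_ds = ToyDataset(valid_files, valid_transform)

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=1, shuffle=False)

    fold_sizes.append((len(train_ds), len(valid_ds)))
    # ... reset the model weights and run the usual training loop here ...
```

With this layout the samplers are no longer needed, because each loader only ever sees the files of its own split. With a cached dataset, building it inside the loop means the caching runs once per fold, which is probably why the MONAI tutorial shows several loading passes.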

Dear @rinkujadhav2013,

thanks a lot for your reply. I implemented this and it works fine.

Have a great day!