Hi Folks,
I am implementing a K-fold cross validation for my PyTorch model, but I seem to have a problem with how I am creating the datasets, the transforms and the DataLoaders.
What I want to do: split all my image volumes (.nii.gz) into k folds and run one training per fold. The data-augmentation transforms (train_transforms) should only be applied to the training data of each fold. Right now, the validation data gets augmented as well.
I used several tutorials for this, but they all download MNIST, where you just set the `train` argument to True for the training data and to False for the validation data. I assume my problem lies in adapting these tutorials to a custom dataset.
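For reference: in torchvision's `MNIST`, the `train` flag only selects which split is loaded; the augmentations come from the separate `transform` argument. With a custom dataset, the equivalent is to build two dataset objects over different file lists, each with its own transform. The sketch below assumes a hypothetical `FileListDataset` class and toy `augment`/`no_augment` functions standing in for the MONAI transforms (MONAI's own `monai.data.Dataset(data=..., transform=...)` follows the same pattern).

```python
from typing import Callable, Sequence

from torch.utils.data import Dataset


class FileListDataset(Dataset):
    """Hypothetical dataset: one dict per sample, transformed on every access."""

    def __init__(self, items: Sequence[dict], transform: Callable[[dict], dict]):
        self.items = list(items)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict:
        # Shallow copy so the transform never mutates the stored item.
        return self.transform(dict(self.items[idx]))


def augment(sample: dict) -> dict:
    # Stand-in for train_transforms: just mark the sample as augmented.
    sample["augmented"] = True
    return sample


def no_augment(sample: dict) -> dict:
    # Stand-in for valid_transforms: no augmentation.
    sample["augmented"] = False
    return sample


files = [{"image": f"img_{i}.nii.gz", "label": f"lbl_{i}.nii.gz"} for i in range(5)]
train_ds = FileListDataset(files[:4], transform=augment)
valid_ds = FileListDataset(files[4:], transform=no_augment)
print(train_ds[0]["augmented"], valid_ds[0]["augmented"])  # True False
```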
My Code:
First, I define my dataset and the transforms in a separate dataset.py script:
```python
import os
from glob import glob

from monai.data import CacheDataset
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    RandAffined,
    RandFlipd,
    RandRotate90d,
    ToTensord,
)
from monai.utils import set_determinism
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset


def prepare(data_dir):
    # Collect images and labels; sort them so that image i is paired with label i
    # (glob returns files in arbitrary, filesystem-dependent order)
    images = sorted(glob(os.path.join(data_dir, 'images', '*.nii.gz')))
    labels = sorted(glob(os.path.join(data_dir, 'labels', '*.nii.gz')))

    # Split off test_size (here 20 %) as the validation set
    train_images, valid_images, train_labels, valid_labels = train_test_split(
        images, labels, test_size=0.2, shuffle=False
    )

    # Combine images and labels into dictionaries
    train_files = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(train_images, train_labels)]
    valid_files = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(valid_images, valid_labels)]

    # Set deterministic training for reproducibility
    set_determinism(seed=0)

    # Define the data transforms and augmentations (random affine, flips, rotations)
    train_transforms = Compose(
        [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),  # Add a channel as 1st element
            RandAffined(
                keys=['image', 'label'],
                mode=('bilinear', 'nearest'),
                prob=0.5,
                rotate_range=(0.174533, 0.174533),  # +-10° in radians
                shear_range=(0.2, 0.5),
            ),
            RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=(0, 1)),
            RandRotate90d(keys=['image', 'label'], prob=0.5, max_k=1, spatial_axes=(0, 1)),
            ToTensord(keys=['image', 'label']),
        ]
    )
    valid_transforms = Compose(
        [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ToTensord(keys=['image', 'label']),
        ]
    )

    # Create cached Datasets (cache_rate=1.0 loads every sample into memory up front)
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
    valid_ds = CacheDataset(data=valid_files, transform=valid_transforms, cache_rate=1.0)
    dataset = ConcatDataset([train_ds, valid_ds])
    return dataset
```
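A small check of what `ConcatDataset` does with the two parts: it only maps a global index to (sub-dataset, local index) and calls that sub-dataset's `__getitem__`, so every index that falls into the first part goes through that part's transform, whichever k-fold split it ends up in. `ToyDataset` is a hypothetical stand-in for `CacheDataset`, and "augmentation" here simply adds 100.

```python
from torch.utils.data import ConcatDataset, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for CacheDataset: applies its own transform on access."""

    def __init__(self, values, transform):
        self.values = list(values)
        self.transform = transform

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        return self.transform(self.values[idx])


train_part = ToyDataset(range(8), transform=lambda x: x + 100)  # "augmented"
valid_part = ToyDataset(range(8, 10), transform=lambda x: x)     # untouched
dataset = ConcatDataset([train_part, valid_part])

# Suppose KFold puts indices 1 and 9 into a validation fold:
print([dataset[i] for i in (1, 9)])  # [101, 9] -> index 1 is still augmented
```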
Next, in my training script, I configure my k-fold cross validation and fetch my dataset:
```python
import torch
from sklearn.model_selection import KFold

from dataset import prepare

# ------------------------K-Fold cross validation----------------------------
# Configure k-fold cross validation
kfolds = 5
kfold = KFold(n_splits=kfolds, shuffle=True, random_state=42)
results = {}  # For fold results
dataset = prepare(input)
```
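`KFold.split` only needs something with a length and returns index arrays, not data. Those indices can therefore be applied to the plain list of file dictionaries, before any dataset or transform exists. A quick check with 10 hypothetical file entries:

```python
from sklearn.model_selection import KFold

# Hypothetical file list; in the real script it would come from the sorted globs.
files = [{"image": f"img_{i}.nii.gz", "label": f"lbl_{i}.nii.gz"} for i in range(10)]

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
all_valid = []
for fold, (train_ids, valid_ids) in enumerate(kfold.split(files)):
    train_files = [files[i] for i in train_ids]
    valid_files = [files[i] for i in valid_ids]
    all_valid.extend(valid_ids.tolist())
    print(fold, len(train_files), len(valid_files))  # 8 train / 2 valid per fold

# Every sample appears in exactly one validation fold.
print(sorted(all_valid) == list(range(10)))  # True
```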
Then I loop over the folds returned by `kfold.split`. In this loop, random subsets of the dataset are sampled and the DataLoaders are initialized:
```python
for fold, (train_ids, valid_ids) in enumerate(kfold.split(dataset)):
    MODULE_LOGGER.info(f"Running fold {fold + 1}/{kfolds}")

    # Sample elements randomly from a given list of ids, without replacement
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    valid_subsampler = torch.utils.data.SubsetRandomSampler(valid_ids)

    # Define data loaders for training and validation data in this fold
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=train_subsampler
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, sampler=valid_subsampler
    )

    # Avoid weight leakage by resetting the weights for each fold
    model.apply(utils.reset_weights)
    # ---------------------------------------------------------------------------

    # ------------------------------Training Loop--------------------------------
    # Start typical PyTorch training loop....
```
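`utils.reset_weights` is not shown in the post. A common way to write such a helper is sketched below (hypothetical, assuming every layer that holds weights implements `reset_parameters()`, as built-in layers like `nn.Conv3d` and `nn.BatchNorm3d` do). If the optimizer is created outside the fold loop, its state (e.g. Adam's moment estimates) also carries over between folds, so re-creating the optimizer per fold is worth considering too.

```python
import torch
from torch import nn


def reset_weights(module: nn.Module) -> None:
    """Hypothetical helper for model.apply(): re-initialise any layer that supports it."""
    reset = getattr(module, "reset_parameters", None)
    if callable(reset):
        reset()


model = nn.Sequential(nn.Conv3d(1, 4, kernel_size=3), nn.BatchNorm3d(4), nn.ReLU())
before = model[0].weight.detach().clone()
model.apply(reset_weights)  # visits every submodule, then the Sequential itself
after = model[0].weight.detach().clone()
print(torch.equal(before, after))  # False: the conv weights were re-drawn
```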
I strongly suspect that my problem lies in the prepare function, where I concatenate the train and validation datasets back into one after splitting them (which more or less makes the split with sklearn's train_test_split pointless). However, I don't see how to loop over the k folds without a concatenated dataset.
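One way around the concatenated dataset is to split the *file list* with KFold inside the loop and build a separate training and validation dataset per fold, each with its own transform. The sketch below stays runnable without MONAI: `DictDataset` is a hypothetical stand-in for `monai.data.Dataset`/`CacheDataset(data=..., transform=...)`, and the two transform functions only tag each sample instead of loading and augmenting volumes.

```python
import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical stand-in for a MONAI Dataset/CacheDataset over file dicts."""

    def __init__(self, data, transform):
        self.data = list(data)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.transform(dict(self.data[idx]))


def train_transforms(sample):
    # Stand-in for the MONAI train_transforms (random affine/flip/rotate).
    sample["augmented"] = torch.tensor(1)
    return sample


def valid_transforms(sample):
    # Stand-in for the MONAI valid_transforms (no augmentation).
    sample["augmented"] = torch.tensor(0)
    return sample


# Hypothetical file list; in the real script: zip of the sorted image/label globs.
files = [{"image": f"img_{i}.nii.gz", "label": f"lbl_{i}.nii.gz"} for i in range(10)]

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
train_counts, valid_counts = [], []
for fold, (train_ids, valid_ids) in enumerate(kfold.split(files)):
    # Split the file list first, then build one dataset per role.
    train_ds = DictDataset([files[i] for i in train_ids], transform=train_transforms)
    valid_ds = DictDataset([files[i] for i in valid_ids], transform=valid_transforms)

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=1)

    # Count augmented samples seen by each loader in this fold.
    train_counts.append(sum(int(b["augmented"].sum()) for b in train_loader))
    valid_counts.append(sum(int(b["augmented"].sum()) for b in valid_loader))
    print(f"fold {fold}: {len(train_ds)} train, {len(valid_ds)} valid, "
          f"{valid_counts[-1]} augmented validation samples")
```

Each fold then gets fresh DataLoaders, and the validation loader never sees `train_transforms`.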
Another thing I noticed: I assume the data is only loaded once, at the beginning of my run (I can see progress bars while the data is loading). I saw a MONAI tutorial in which several DataLoaders were initialized, depending on the number of folds.
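As far as I know, `CacheDataset` runs the deterministic transforms (here `LoadImaged`, `AddChanneld`) when it is constructed, with a progress bar, and applies the random ones on every access; building one `CacheDataset` per fold would therefore repeat the loading k times. An alternative is to do the expensive, deterministic part once and apply only the fold-specific transforms per fold. The plain-PyTorch sketch below illustrates the idea; `load_volume`, `augment` and `TransformedView` are hypothetical names, and wiring this up with MONAI classes is left as an assumption.

```python
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, Subset

load_calls = 0


def load_volume(i):
    """Hypothetical stand-in for LoadImaged + AddChanneld (expensive, deterministic)."""
    global load_calls
    load_calls += 1
    return {"image": [i, i + 1, i + 2], "label": i % 2}


def augment(sample):
    # Stand-in for the random transforms: a "flip" that returns a new dict,
    # so the cached sample itself is never modified.
    return {**sample, "image": sample["image"][::-1]}


def identity(sample):
    return sample


class TransformedView(Dataset):
    """Hypothetical wrapper: per-fold transform on top of shared, pre-loaded data."""

    def __init__(self, base, transform):
        self.base = base
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        return self.transform(self.base[idx])


# Expensive loading happens exactly once, before the k-fold loop.
cache = [load_volume(i) for i in range(10)]

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
for train_ids, valid_ids in kfold.split(cache):
    train_ds = TransformedView(Subset(cache, train_ids.tolist()), augment)
    valid_ds = TransformedView(Subset(cache, valid_ids.tolist()), identity)
    # Touch every sample once, as one epoch would.
    train_samples = [train_ds[i] for i in range(len(train_ds))]
    valid_samples = [valid_ds[i] for i in range(len(valid_ds))]

print(load_calls)         # 10: loading ran once, not once per fold
print(cache[0]["image"])  # [0, 1, 2]: augmentation did not alter the cache
```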
I am confused about how to run k-fold CV on a custom dataset with a typical PyTorch training loop.
Thanks!