Split dataset with some conditions

Hey there :slight_smile:

I have to split a training dataset into training and validation subsets with an 80/20 ratio, but under some conditions. Let me first introduce the dataset. I am using a medical dataset with the attributes patient_id, image_name, target and some other features, which are not important at the moment. Each patient has a unique patient_id, and one patient_id maps to multiple image_names: every patient has at least two and at most 120 images. I want to split the training dataset into training and validation such that, whenever a patient appears in train, val or both, that subset contains at least two images of that patient. I think I have a solution, but it is quite slow and has too many conditional statements. Furthermore, it does not have a stratified option yet.
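Since every patient has at least two images, one possible simplification (a different technique from the per-image assignment in the implementation below) is to split by patient instead of by image: each patient goes entirely to train or entirely to val, so any patient that appears in a subset brings all of its images, and therefore at least two, with it. The trade-off is that the 80/20 ratio becomes approximate, because patients cannot be divided. A minimal stdlib-only sketch, assuming the rows are given as (patient_id, image_name) pairs; the function name split_by_patient is hypothetical:

```python
import random
from collections import defaultdict


def split_by_patient(rows, ratio_train=0.8, seed=44):
    """Assign whole patients to train or val (hypothetical helper).

    rows: iterable of (patient_id, image_name) pairs (assumed input format).
    Returns (train_images, val_images) as lists of image names.
    """
    # Collect all images of each patient.
    images_by_patient = defaultdict(list)
    for patient_id, image_name in rows:
        images_by_patient[patient_id].append(image_name)

    # Sort before shuffling so the result only depends on the seed.
    patients = sorted(images_by_patient)
    random.Random(seed).shuffle(patients)

    total = sum(len(imgs) for imgs in images_by_patient.values())
    target = round(total * ratio_train)  # desired number of train images
    train, val = [], []
    for patient_id in patients:
        imgs = images_by_patient[patient_id]
        # Greedy: put the whole patient into train while it still fits the target.
        if len(train) + len(imgs) <= target:
            train.extend(imgs)
        else:
            val.extend(imgs)
    return train, val
```

scikit-learn's GroupShuffleSplit performs the same kind of grouped split, with the caveat that its test_size is a fraction of groups (patients), not of images.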

Here is what I implemented so far:

import math
import os

import numpy as np
import pandas as pd


def split_meta_file_to_train_and_val(path_dataset: str, ratio_train: float = 0.8, ratio_val: float = 0.2, shuffle: bool = True, seed: int = 44):
    # Floating-point ratios rarely sum to exactly 1, so compare with a tolerance.
    if not math.isclose(ratio_train + ratio_val, 1.0):
        raise ValueError(f"The sum of the parameters ratio_train and ratio_val should be 1. However, the sum is actually {ratio_train + ratio_val}.")
    meta = pd.read_csv(os.path.join(path_dataset, 'train.csv'))
    if shuffle:
        # Shuffle the rows and keep each row's original position in 'old_index'.
        meta = meta.sample(frac=1.0, random_state=seed)
        meta = meta.assign(old_index=meta.index).reset_index(drop=True)
    train_length = round(len(meta) * ratio_train)
    rng = np.random.default_rng(seed)
    train_idx = []    # image names assigned to train, in insertion order
    in_train = set()  # the same names as a set, for fast membership tests
    # All images per patient, how many of them are not yet in train, and which patients already have images in train.
    images_by_patient = meta.groupby('patient_id')['image_name'].apply(list).to_dict()
    remaining = {p: len(imgs) for p, imgs in images_by_patient.items()}
    in_list = set()

    def pick_others(image_name, patient_id, k):
        # k distinct random images of the same patient that are neither the current image nor already in train.
        others = [img for img in images_by_patient[patient_id] if img != image_name and img not in in_train]
        return list(rng.choice(others, size=k, replace=False))

    def add_to_train(images, patient_id):
        train_idx.extend(images)
        in_train.update(images)
        in_list.add(patient_id)
        remaining[patient_id] -= len(images)  # images of this patient still left for val

    for image_name, patient_id in zip(meta['image_name'], meta['patient_id']):
        if image_name in in_train:
            continue  # already moved to train together with another image of this patient
        left = remaining[patient_id]
        seen = patient_id in in_list
        free = train_length - len(train_idx)
        # Exactly three images left: move all three to train. Splitting them 2/1 would leave one
        # subset with a single image, but the model needs two images from one patient.
        if left == 3 and free >= 3:
            add_to_train([image_name] + pick_others(image_name, patient_id, 2), patient_id)
        # No image of this patient in train yet and 2 or >= 4 left: start with a pair, so either
        # none or at least two images remain for val.
        elif not seen and free >= 2 and (left == 2 or left >= 4):
            add_to_train([image_name] + pick_others(image_name, patient_id, 1), patient_id)
        # Patient already has at least two images in train and >= 3 left: one more still leaves >= 2 for val.
        elif seen and free >= 1 and left >= 3:
            add_to_train([image_name], patient_id)
        # Patient already in train and exactly 2 left: move both, so val never gets a single image.
        elif seen and free >= 2 and left == 2:
            add_to_train([image_name] + pick_others(image_name, patient_id, 1), patient_id)

    train_image_names = meta[meta['image_name'].isin(train_idx)]
    val_image_names = meta[~meta['image_name'].isin(train_idx)]
    return train_image_names, val_image_names

Is there an easier way to do this? If so, I would appreciate a hint or tip.
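For the missing stratified option, one hint is to stratify at the patient level. The labelling rule here is an assumption: a patient counts as positive if any of its images has target == 1. The patient-level split is then run separately within each label group, so train keeps roughly 80% of each group's images and every patient still lands wholly in one subset. A minimal stdlib-only sketch; the function name stratified_split_by_patient and the (patient_id, image_name, target) row format are hypothetical. scikit-learn also provides StratifiedGroupKFold (since version 1.0), where one fold of a 5-fold split gives roughly a 20% validation set.

```python
import random
from collections import defaultdict


def stratified_split_by_patient(rows, ratio_train=0.8, seed=44):
    """Patient-level split, stratified by a per-patient label (hypothetical helper).

    rows: iterable of (patient_id, image_name, target) triples (assumed format).
    Assumption: a patient's label is 1 if any of its images has target == 1, else 0.
    """
    images_by_patient = defaultdict(list)
    label_by_patient = defaultdict(int)
    for patient_id, image_name, target in rows:
        images_by_patient[patient_id].append(image_name)
        label_by_patient[patient_id] = max(label_by_patient[patient_id], int(target))

    rng = random.Random(seed)
    train, val = [], []
    # Split each label group on its own so both subsets get a similar share of it.
    for label in sorted(set(label_by_patient.values())):
        patients = sorted(p for p, lab in label_by_patient.items() if lab == label)
        rng.shuffle(patients)
        target_count = round(sum(len(images_by_patient[p]) for p in patients) * ratio_train)
        group_train = 0
        for p in patients:
            imgs = images_by_patient[p]
            # Greedy: the whole patient goes to train while the group's target still allows it.
            if group_train + len(imgs) <= target_count:
                train.extend(imgs)
                group_train += len(imgs)
            else:
                val.extend(imgs)
    return train, val
```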

Best regards :smiley: