Balance Classes and Random Split

Hello,

I have an ImageFolder object containing data points from 3 unbalanced classes. I want to randomly choose n points from each class, where n is the minimum class count, and then split the resulting dataset into a training and validation set (either keeping the class proportions or splitting randomly).
I have this code, but I am not sure whether I am doing it correctly.

dataset = datasets.ImageFolder(image_dir, transform=transformations)
images_label = {image[0]: image[1] for image in dataset.imgs}

class_counts = {}
for image_id in images_label.keys():
    label = images_label[image_id]
    class_counts[label] = class_counts.get(label, 0) + 1

class_weights = list(class_counts.values())
class_weights /= np.sum(class_weights)

sampler = torch.utils.data.sampler.WeightedRandomSampler(class_weights, sum(class_counts.values()))

data_loader = DataLoader(dataset, sampler=sampler)

train_length = int(0.8 * len(data_loader))
test_length = len(data_loader) - train_length

train_dataset, test_dataset = torch.utils.data.random_split(data_loader.dataset, (train_length, test_length))

dataloader_train = torch.utils.data.DataLoader(train_dataset)
dataloader_test = torch.utils.data.DataLoader(test_dataset)

In your current approach it seems that you are passing the class counts to the WeightedRandomSampler, while each sample should get its own weight, as described in this post.
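A minimal sketch of the per-sample weighting (the `targets` array here is a hypothetical stand-in for `[s[1] for s in dataset.imgs]`): each sample is weighted by the inverse frequency of its class, so minority-class samples are drawn more often.

```python
import numpy as np
import torch

# Hypothetical label list standing in for [s[1] for s in dataset.imgs].
targets = np.array([0, 0, 0, 0, 1, 1, 2])

# Count samples per class, then assign each *sample* the inverse
# frequency of its class as its weight.
class_counts = np.bincount(targets)            # e.g. [4, 2, 1]
sample_weights = 1.0 / class_counts[targets]   # one weight per sample

# The sampler expects one weight per sample, not one per class.
sampler = torch.utils.data.WeightedRandomSampler(
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(targets),
    replacement=True,
)
```

With `replacement=True`, each epoch draws roughly balanced batches even though the underlying dataset stays unbalanced.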

I’m also unsure whether each batch should contain at least n samples from each class, or each dataset split.
In the former case, you could write a custom sampler (and remove the WeightedRandomSampler) so that indices are sampled according to your condition.
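If instead you want the latter, i.e. to undersample each split to n = min class count and then split 80/20 as in your code, one sketch (again using a hypothetical `targets` array in place of the ImageFolder labels, and `Subset` to select the balanced indices) would be:

```python
import numpy as np
import torch
from torch.utils.data import Subset, random_split

# Hypothetical labels standing in for [s[1] for s in dataset.imgs].
targets = np.array([0] * 50 + [1] * 30 + [2] * 20)

# n = smallest class count; draw n random indices from each class.
n = np.bincount(targets).min()
balanced_idx = np.concatenate([
    np.random.permutation(np.where(targets == c)[0])[:n]
    for c in np.unique(targets)
])

# Wrap the original dataset (here just the label array, but an
# ImageFolder works the same way) and split it 80/20.
balanced = Subset(targets, balanced_idx.tolist())
train_len = int(0.8 * len(balanced))
train_set, val_set = random_split(balanced, [train_len, len(balanced) - train_len])
```

Note that `random_split` shuffles across classes, so the two splits are only balanced in expectation; if you need exact per-class proportions in each split, you would split the per-class index lists before concatenating.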