I have an ImageFolder object with datapoints of 3 unbalanced classes and I want to randomly choose n points from each class where n is the minimum class count and then split the new dataset formed into a training and validation set (either keeping the proportions or randomly).
I have this code but I am not sure if I am doing it correctly.
dataset = datasets.ImageFolder(image_dir, transform=transformations)
images_label = {image[0]: image[1] for image in dataset.imgs}
class_counts = {}
for image_id in images_label.keys():
label = images_label[image_id]
class_counts[label] = class_counts.get(label, 0) + 1
class_weights = list(class_counts.values())
class_weights /= np.sum(class_weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(class_weights, sum(class_counts.values()))
data_loader = DataLoader(dataset, sampler=sampler)
train_length=int(0.8* len(data_loader))
train_dataset,test_dataset = torch.utils.data.random_split(data_loader.dataset,(train_length,test_length))
dataloader_train = torch.utils.data.DataLoader(train_dataset)
dataloader_test = torch.utils.data.DataLoader(test_dataset)