How do I update (filter) my ImageFolder dataset in PyTorch?

I am working on a dataset where I need to measure the accuracy on classes with fewer than 20 samples. First, I used PyTorch's `ImageFolder` to load all the images in the folders:

# Load every image below the dataset root with torchvision's ImageFolder.
# NOTE(review): no transform is passed, so samples are presumably PIL images;
# a DataLoader used for inference on this dataset would need e.g. a ToTensor
# transform — confirm how `data_loader` is built.
dataset = ImageFolder(root='/content/drive/MyDrive/data/dataset')

Now, to get the classes with fewer than 20 samples, I use:

def get_class_distribution(dataset_obj):
    """Count how many samples each class of an ImageFolder-style dataset has.

    Args:
        dataset_obj: dataset exposing ``class_to_idx`` (class name -> integer
            label). If it also exposes ``targets`` (the integer label of every
            sample, as torchvision's ``ImageFolder`` does), the labels are read
            from there; otherwise the dataset is iterated and element ``[1]``
            of each sample is taken as its label.

    Returns:
        dict mapping every class name in ``class_to_idx`` to its sample count
        (0 for classes that have no samples).
    """
    # Invert class_to_idx here instead of relying on a global ``idx2class``.
    idx_to_class = {idx: name for name, idx in dataset_obj.class_to_idx.items()}
    count_dict = {name: 0 for name in dataset_obj.class_to_idx}

    # Reading the stored labels avoids loading/decoding every image just to
    # count classes; fall back to full iteration for datasets without them.
    labels = getattr(dataset_obj, "targets", None)
    if labels is None:
        labels = (element[1] for element in dataset_obj)

    for label in labels:
        # int() also accepts 0-d tensors, which would not match int dict keys.
        count_dict[idx_to_class[int(label)]] += 1

    return count_dict
# print("Distribution of classes: \n", get_class_distribution(dataset))
class_distribution = get_class_distribution(dataset)

# Names of the classes with strictly fewer than 20 samples, matching the task.
sampled_classes = [class_name for class_name, n_samples in class_distribution.items() if n_samples < 20]

I get the list of classes correctly, but I am unsure how to proceed with inference. How do I convert or update the `ImageFolder` dataset so that I can use the filtered dataset in the following code:

# Test model performance for classes with less than 20 samples.
# Collects, without tracking gradients, one NumPy array of predicted labels
# and one of true labels per batch.
y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in tqdm(data_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        y_test_pred = model(x_batch)
        # torch.max over dim 1 returns (values, indices); the indices are used
        # as predicted labels (presumably dim 1 holds the per-class scores).
        y_pred_tag = torch.max(y_test_pred, dim=1).indices
        y_pred_list.append(y_pred_tag.cpu().numpy())
        y_true_list.append(y_batch.cpu().numpy())

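Depending on how much data there is, it may be simpler to filter out unwanted samples at inference time (at a moderate performance cost, since batches become smaller). Note that `skip_classes` must contain integer label indices (as in `dataset.class_to_idx`), not class names. To evaluate only the rare classes, skip every class that is *not* in `sampled_classes`.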

# sketch of idea, might need refinement depending on specific shape
def filter_data(x_batch, y_batch, skip_classes):
    """Drop the samples of a batch whose label is listed in ``skip_classes``.

    Args:
        x_batch: tensor whose first dimension indexes the samples of the batch.
        y_batch: 1-D integer tensor of labels, one per sample in ``x_batch``.
        skip_classes: iterable of integer labels (class *indices*, not class
            names) whose samples should be removed.

    Returns:
        ``(filtered_x_batch, filtered_y_batch)`` holding only the kept samples.
        The input tensors are left unmodified.
    """
    # Build a keep-mask instead of overwriting labels with a -1 sentinel, so the
    # caller's y_batch is not mutated. (``class`` is a reserved keyword and
    # cannot be used as the loop variable.)
    mask = torch.ones_like(y_batch, dtype=torch.bool)
    for skip_label in skip_classes:
        mask &= y_batch != skip_label
    # Indexing with the mask alone works for inputs of any rank >= 1.
    filtered_x_batch = x_batch[mask]
    filtered_y_batch = y_batch[mask]
    return filtered_x_batch, filtered_y_batch
...
# Test model performance for classes with less than 20 samples.
# NOTE(review): ``skip_classes`` must hold integer label indices (see
# ``dataset.class_to_idx``) for every class that should NOT be evaluated.
y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in tqdm(data_loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        x_batch, y_batch = filter_data(x_batch, y_batch, skip_classes)
        # Every sample of this batch may have been filtered out; skip it rather
        # than run the model on an empty batch and append empty arrays.
        if y_batch.numel() == 0:
            continue
        y_test_pred = model(x_batch)
        _, y_pred_tag = torch.max(y_test_pred, dim=1)
        y_pred_list.append(y_pred_tag.cpu().numpy())
        y_true_list.append(y_batch.cpu().numpy())