I am working on a dataset where I need to find the accuracy for classes with fewer than 20 samples. First, I used torchvision's `ImageFolder` to load all the images from the class folders.
# Load every image under the dataset root; ImageFolder is given the root
# directory explicitly by keyword.
dataset = ImageFolder(root='/content/drive/MyDrive/data/dataset')
Now, to get the classes with fewer than 20 samples, I use:
def get_class_distribution(dataset_obj):
    """Count how many samples each class of ``dataset_obj`` contains.

    Args:
        dataset_obj: A dataset exposing ``class_to_idx`` (a mapping from
            class name to integer label). If it also exposes ``targets``
            (one label per sample, as torchvision's ``ImageFolder`` does),
            the labels are read from there; otherwise the dataset is
            iterated and the second item of every element is used as the
            label.

    Returns:
        dict: ``{class_name: sample_count}`` with an entry for every class
        in ``class_to_idx``, including classes that have zero samples.

    Raises:
        KeyError: If a label is not one of the values of ``class_to_idx``.
    """
    class_to_idx = dataset_obj.class_to_idx
    # Derive the label -> name mapping here instead of depending on a
    # module-level ``idx2class`` that may not exist.
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    count_dict = {name: 0 for name in class_to_idx}

    targets = getattr(dataset_obj, "targets", None)
    if targets is not None:
        # Reading the stored labels avoids loading every sample just to
        # look at its label. int() normalises numpy/tensor scalars.
        # NOTE(review): these are the raw labels; if the dataset applies a
        # target_transform, they can differ from what iteration yields.
        labels = (int(label) for label in targets)
    else:
        labels = (element[1] for element in dataset_obj)

    for label in labels:
        count_dict[idx_to_class[label]] += 1
    return count_dict
# print("Distribution of classes: \n", get_class_distribution(dataset))
# Per-class sample counts, keyed by class name.
class_distribution = get_class_distribution(dataset)
# The requirement is "less than 20 samples", so the comparison is strict:
# a class with exactly 20 samples is not selected.
sampled_classes = [
    class_name
    for class_name, samples in class_distribution.items()
    if samples < 20
]
I get the list of classes correctly, but my question is how to proceed with inference. How do I turn this list into a filtered `ImageFolder` dataset so that I can use the filtered dataset in the following code:
# Evaluate the model on the classes with less than 20 samples, collecting
# the predicted and the true labels of every batch as NumPy arrays.
y_pred_list = []
y_true_list = []
with torch.no_grad():  # inference only, no gradients are needed
    for x_batch, y_batch in tqdm(data_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        y_test_pred = model(x_batch)
        # Index of the highest score along dim 1 is the predicted class.
        y_pred_tag = y_test_pred.argmax(dim=1)
        y_pred_list.append(y_pred_tag.cpu().numpy())
        y_true_list.append(y_batch.cpu().numpy())