Finding number of samples per class in multi label classification

I am performing multi-label image classification using DataLoaders. How can I see how many training and test images there are for each class?

I assume your labels are stored in a multi-hot encoded format (one 0/1 entry per class), as is usual for multi-label classification.
If that's the case, you could iterate your Dataset once and count all class occurrences:

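import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, num_classes):
        # 100 random samples with 2 features each
        self.data = torch.randn(100, 2)
        # Multi-hot targets: each entry is 0 or 1, so a sample can belong to several classes
        self.target = torch.empty(100, num_classes, dtype=torch.long).random_(2)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        return x, y

    def __len__(self):
        return len(self.data)


num_classes = 10
dataset = MyDataset(num_classes)
labels = torch.zeros(num_classes, dtype=torch.long)

# Iterate the dataset once and accumulate how often each class occurs
for _, target in dataset:
    labels += target

print(labels)  # number of samples that contain each class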
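Since the question asks for separate training and test counts through DataLoaders, the same idea can be applied per split by summing each batch of targets over the batch dimension. This is a minimal sketch, assuming the toy `MyDataset` from above and a hypothetical 80/20 split made with `random_split`; the helper `count_per_class` is an illustrative name, not a PyTorch API.

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class MyDataset(Dataset):
    # Same toy dataset as above: random features, multi-hot targets in {0, 1}
    def __init__(self, num_classes):
        self.data = torch.randn(100, 2)
        self.target = torch.empty(100, num_classes, dtype=torch.long).random_(2)

    def __getitem__(self, index):
        return self.data[index], self.target[index]

    def __len__(self):
        return len(self.data)


def count_per_class(loader, num_classes):
    # Hypothetical helper: sum the multi-hot targets of every batch over dim 0
    counts = torch.zeros(num_classes, dtype=torch.long)
    for _, target in loader:
        counts += target.sum(dim=0)
    return counts


num_classes = 10
dataset = MyDataset(num_classes)

# Assumed 80/20 train/test split, seeded only so the split is reproducible
train_set, test_set = random_split(
    dataset, [80, 20], generator=torch.Generator().manual_seed(0)
)

train_loader = DataLoader(train_set, batch_size=16, shuffle=False)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)

train_counts = count_per_class(train_loader, num_classes)
test_counts = count_per_class(test_loader, num_classes)
print("train:", train_counts.tolist())
print("test: ", test_counts.tolist())
```

Iterating a DataLoader also loads every input, which can be slow for real image datasets. If your targets are kept in a tensor attribute (like `self.target` here), a cheaper option may be to index it with the subset's indices, e.g. `dataset.target[train_set.indices].sum(dim=0)`.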

Thanks a lot. Seems easy after seeing the solution.
