I am performing multi-label image classification and using DataLoaders for this. How can I see the breakdown of the number of training and test images present for each class?
I assume your labels are saved in a one-hot (multi-hot) encoded format for multi-label classification, i.e. one 0/1 entry per class for each sample.
If that’s the case, you could iterate your Dataset once and just count all class occurrences:
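```python
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, num_classes):
        # Random dummy data: 100 samples with 2 features each
        self.data = torch.randn(100, 2)
        # Random multi-hot targets: each entry is 0 or 1
        self.target = torch.empty(100, num_classes, dtype=torch.long).random_(2)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        return x, y

    def __len__(self):
        return len(self.data)


num_classes = 10
dataset = MyDataset(num_classes)

# Accumulate the per-class occurrence counts over all samples
labels = torch.zeros(num_classes, dtype=torch.long)
for _, target in dataset:
    labels += target
print(labels)
```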
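Since you mentioned DataLoaders and want the breakdown for both splits, the same idea should also work batch-wise. Here is a minimal sketch, assuming each loader yields `(data, target)` batches with multi-hot targets of shape `(batch_size, num_classes)`. The helper `count_labels` and the random `TensorDataset` stand-ins are made up for illustration; you would plug in your own training and test loaders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def count_labels(loader, num_classes):
    """Sum multi-hot targets over all batches to get per-class counts."""
    counts = torch.zeros(num_classes, dtype=torch.long)
    for _, target in loader:
        # target: (batch_size, num_classes); summing over dim 0 counts
        # how many samples in this batch carry each class
        counts += target.long().sum(dim=0)
    return counts


num_classes = 10

# Hypothetical stand-ins for the real training and test data
train_targets = torch.randint(0, 2, (80, num_classes))
test_targets = torch.randint(0, 2, (20, num_classes))
train_loader = DataLoader(
    TensorDataset(torch.randn(80, 2), train_targets), batch_size=16
)
test_loader = DataLoader(
    TensorDataset(torch.randn(20, 2), test_targets), batch_size=16
)

train_counts = count_labels(train_loader, num_classes)
test_counts = count_labels(test_loader, num_classes)

for c in range(num_classes):
    print(f"class {c}: train={train_counts[c].item()}, test={test_counts[c].item()}")
```

Note that this loads every sample once, so for a large image dataset it might be cheaper to count directly from the stored label tensor or annotation file, if you have access to it.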
Thanks a lot. Seems easy after seeing the solution.