I am performing multi-label image classification and I am using DataLoaders for this. How can I see the breakdown of the number of training and test images present for each class?
I assume your labels are saved in a one-hot (multi-hot) encoded format for multi-label classification.
If that’s the case, you could iterate your Dataset once and just count all class occurrences:
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, num_classes):
        # Random dummy data; each target is a multi-hot vector of 0s and 1s
        self.data = torch.randn(100, 2)
        self.target = torch.empty(100, num_classes, dtype=torch.long).random_(2)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        return x, y

    def __len__(self):
        return len(self.data)

num_classes = 10
dataset = MyDataset(num_classes)
# Accumulate per-class counts by summing the multi-hot targets
labels = torch.zeros(num_classes, dtype=torch.long)
for _, target in dataset:
    labels += target
print(labels)  # one count per class
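If you would rather count through your DataLoaders and get the training and test breakdown separately, something like the following might work. It is only a sketch: `count_labels`, `train_ds`, and `test_ds` are hypothetical names, the random `TensorDataset`s stand in for your real data, and it assumes every batch yields `(input, target)` with multi-hot targets of shape `[batch_size, num_classes]`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def count_labels(loader, num_classes):
    """Sum multi-hot targets over all batches to get per-class counts."""
    counts = torch.zeros(num_classes, dtype=torch.long)
    for _, target in loader:
        # target: [batch_size, num_classes]; sum over the batch dimension
        counts += target.long().sum(dim=0)
    return counts


num_classes = 10
# Hypothetical stand-ins for the real training and test datasets
train_ds = TensorDataset(torch.randn(100, 2), torch.randint(0, 2, (100, num_classes)))
test_ds = TensorDataset(torch.randn(40, 2), torch.randint(0, 2, (40, num_classes)))

train_loader = DataLoader(train_ds, batch_size=16)
test_loader = DataLoader(test_ds, batch_size=16)

train_counts = count_labels(train_loader, num_classes)
test_counts = count_labels(test_loader, num_classes)
for c in range(num_classes):
    print(f"class {c}: train={train_counts[c].item()}, test={test_counts[c].item()}")
```

Note that this loads every sample once (including any image decoding and transforms), so if your Dataset already keeps all targets in a single tensor, summing that tensor directly would be faster.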
Thanks a lot. Seems easy after seeing the solution.