How do I find the number of classes in a PyTorch dataset? I expect a certain number, but I want to verify it after loading the data with ImageFolder. Currently I only know how to find the length of the dataset.
Can you please check the unique values in the labels?
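A minimal sketch of that check, assuming the dataset yields (sample, label) pairs the way ImageFolder does. The helper name count_labels and the toy data are hypothetical. Iterating over an ImageFolder this way decodes every image, so Counter(train_set.targets) is the cheaper route there.

```python
from collections import Counter
from typing import Any, Iterable, Tuple


def count_labels(dataset: Iterable[Tuple[Any, int]]) -> Counter:
    """Count how often each label occurs in a dataset of (sample, label) pairs.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    return Counter(label for _, label in dataset)


# Stand-in dataset: (sample, label) tuples instead of real images.
toy = [("img0", 0), ("img1", 2), ("img2", 0), ("img3", 1)]
counts = count_labels(toy)
print(sorted(counts))  # the unique labels
print(len(counts))     # how many distinct labels were found
```

This counts only the labels that actually occur in the data, which is what you want when sanity-checking a split.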
You can get the classes by using the .classes attribute of the dataset.
For example, if you've loaded the training data using ImageFolder:
import os
from torchvision import datasets

train_set = datasets.ImageFolder(os.path.join(data_dir, "train"), data_transforms["train"])
train_set.classes
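This returns the list of class names (one per subfolder, sorted alphabetically), so len(train_set.classes) gives the number of classes in your training/test/val set.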
Thank you, that works!