Number of classes in dataset

How do I find the number of classes in a PyTorch dataset? I expect it to be a certain number, but I want to verify it after loading the data with ImageFolder. Currently I only know how to find the length of the dataset.

Can you please check the unique values in the labels?
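A minimal sketch of that check, assuming the dataset behaves like a torch Dataset (indexing returns a (sample, label) pair). The helper label_counts and the ToyDataset stand-in are hypothetical names for illustration. torchvision's ImageFolder stores all labels in its targets attribute, so the helper reads that when present instead of iterating, which would load and transform every image.

```python
from collections import Counter


def _as_key(label):
    # Tensor labels hash by identity, so unwrap scalars (torch/NumPy) to plain Python values
    return label.item() if hasattr(label, "item") else label


def label_counts(dataset):
    """Count samples per label; the number of keys is the number of distinct labels."""
    targets = getattr(dataset, "targets", None)  # ImageFolder provides this list
    if targets is None:
        # Fallback: iterate over (sample, label) pairs (slow for image datasets)
        targets = [label for _, label in dataset]
    return Counter(_as_key(t) for t in targets)


class ToyDataset:
    """Hypothetical stand-in for a Dataset: indexing returns (sample, label)."""

    def __init__(self, labels):
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return f"img_{i}", self.labels[i]  # IndexError at the end stops iteration


counts = label_counts(ToyDataset([0, 2, 1, 2, 0, 2]))
print(len(counts))             # → 3
print(sorted(counts.items()))  # → [(0, 2), (1, 1), (2, 3)]
```

With a real ImageFolder you would call label_counts(train_set). Note that this only counts labels that actually occur in the data, which is exactly what you want when checking for missing classes.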

You can get the classes by using the .classes attribute of the ImageFolder dataset.

For example, if you've loaded the training data using ImageFolder:

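import os
from torchvision import datasets
# data_dir and data_transforms are assumed to be defined earlier in your script
train_set = datasets.ImageFolder(os.path.join(data_dir, "train"),
                                 data_transforms["train"])

print(train_set.classes)       # class names, one per subdirectory
print(len(train_set.classes))  # number of classes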

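train_set.classes gives you the list of class names ImageFolder found (one per subdirectory, sorted alphabetically), so len(train_set.classes) is the number of classes. The same works for your test and validation sets.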
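To see where that list comes from: ImageFolder treats each subdirectory of the root folder as one class, sorts the names, and assigns indices in that order (the mapping is available as train_set.class_to_idx). The sketch below mimics that logic with only the standard library; find_class_names and EXPECTED_NUM_CLASSES are hypothetical names, and this is a re-implementation for illustration, not torchvision's own function. It lets you check the expected count against the folder layout directly.

```python
import os
import tempfile


def find_class_names(root):
    """Sketch of ImageFolder's class discovery: sorted subdirectory names -> indices."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


# Demo on a throwaway folder layout (hypothetical class names)
with tempfile.TemporaryDirectory() as root:
    for name in ["dog", "cat", "bird"]:
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_class_names(root)

print(classes)       # → ['bird', 'cat', 'dog']
print(class_to_idx)  # → {'bird': 0, 'cat': 1, 'dog': 2}

EXPECTED_NUM_CLASSES = 3  # hypothetical: the number you expect
assert len(classes) == EXPECTED_NUM_CLASSES
```

On the loaded dataset, the equivalent check is assert len(train_set.classes) == EXPECTED_NUM_CLASSES.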

Thank you, that works!