Torchvision Datasets

Hello everyone,

I have a question about torchvision datasets. Is each of these datasets its own class, or is it just a link where you can download the files? I'm trying to run the SRGAN sample program, and it uses a torchvision dataset, but I want to run it on my own dataset and I'm getting errors. Any help would be greatly appreciated! Thank you!

import torchvision

USING_STL = True
if USING_STL:
    DatasetSubclass = torchvision.datasets.STL10
else:
    DatasetSubclass = torchvision.datasets.ImageNet

class Dataset(DatasetSubclass):
    def __init__(self, *args, **kwargs):
        # Remove the SR-specific size arguments before the remaining
        # arguments are forwarded to the torchvision base class
        hr_size = kwargs.pop('hr_size', [96, 96])
        lr_size = kwargs.pop('lr_size', [24, 24])
        super().__init__(*args, **kwargs)

How do I replace STL10 or ImageNet dataset with my own dataset?

The torchvision datasets are subclasses of torch.utils.data.Dataset, but each one is written for one specific dataset: it expects that dataset's file layout and, in most cases, downloads it from the web for you. So you cannot point them at a different dataset. If you want to use your own data, you can look at this tutorial here, or, if your data is already organized with one folder per class, you can use the ImageFolder class here.

Thank you! I’ll definitely take a look.