Hi all,
I am trying to load the STL10 dataset by passing a custom transformation class. My aim is to obtain two augmented versions of each image. But when I pass the class as the transform argument, I get a message that the class is not iterable. I tried making the class simpler, but I got the same message. Does this mean we cannot pass a class as a transform when loading datasets?
Here is a sample of code:
```python
from torchvision import datasets, transforms

color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=96),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])


class Trans(object):
    """Apply the same random pipeline twice to get two augmented views."""

    def __init__(self, transform):
        self.transform = transform

    # The method must be named __call__; a misspelling such as __cal__
    # leaves Trans instances non-callable.
    def __call__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj


dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
                         transform=Trans(data_transforms))
```
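A minimal, torch-free sketch of why the method name matters: torchvision datasets apply `transform` by calling it on each image, so any object whose class defines `__call__` works as a transform. A method spelled `__cal__` is just an ordinary method, so instances of that class cannot be called. The `fake_augment` function below is a hypothetical stand-in for a random torchvision pipeline, and the class names are illustrative only.

```python
import random


class TwoViews:
    """Return two independently augmented views of one sample."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        return self.transform(sample), self.transform(sample)


class TypoTwoViews:
    """Same as TwoViews, but with the method misspelled as __cal__."""

    def __init__(self, transform):
        self.transform = transform

    def __cal__(self, sample):  # not a special method, so instances are not callable
        return self.transform(sample), self.transform(sample)


def fake_augment(x):
    # Hypothetical stand-in for a random augmentation: adds random noise.
    return x + random.random()


good = TwoViews(fake_augment)
bad = TypoTwoViews(fake_augment)

xi, xj = good(1.0)
print(callable(good), callable(bad))  # True False

try:
    bad(1.0)
except TypeError as err:
    print(err)  # 'TypoTwoViews' object is not callable
```

Because each call to the wrapped pipeline draws fresh random parameters, `xi` and `xj` are two different augmentations of the same input, which is the pair the post is after.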
Any help is appreciated.
Best