Loading a dataset by passing a class of transformations

Hi all,

I am trying to load the STL10 dataset by passing a class as the transform. My aim is to obtain two augmented versions of each image. But when I pass the class as the transform argument, I get a message that the class is not iterable. I tried making the class simpler, but I got the same message. Does this mean that we cannot pass a class as a transform when loading datasets?

Here is a sample of code:

color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)

data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=96),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])

class Trans(object):

    def __init__(self, transform):
        self.transform = transform

    def __cal__(self, sample):
        xi = self.transform(sample)
        xj = self.transform(sample)
        return (xi, xj)

dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
                         transform=Trans(data_transforms))

Any help is appreciated.

Best

Could you post the complete error message with the stack trace, please?
Also, you have a typo in the method name: the class defines __cal__ instead of __call__ :wink:.

Thank you for your answer. The problem is solved. It was just due to the typo in the __call__ definition. :slight_smile:
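
For anyone who lands on this thread later, here is a minimal sketch of why the spelling matters. Python only treats an object as callable if its class defines __call__, so a wrapper with a misspelled method cannot be used as a transform. The names TwoViews, Misspelled and random_shift are hypothetical and only for illustration; random_shift is a plain-Python stand-in for the random augmentation pipeline so that the snippet runs without torchvision.

```python
import random


class TwoViews:
    """Apply the same random transform twice and return both results."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):  # spelled __call__, not __cal__
        xi = self.transform(sample)
        xj = self.transform(sample)
        return xi, xj


class Misspelled:
    """Same wrapper with the typo from the question, for comparison."""

    def __init__(self, transform):
        self.transform = transform

    def __cal__(self, sample):  # typo: Python never invokes this on obj(sample)
        return self.transform(sample), self.transform(sample)


def random_shift(values):
    """Hypothetical stand-in for a random augmentation pipeline."""
    offset = random.random()
    return [v + offset for v in values]


two_views = TwoViews(random_shift)
broken = Misspelled(random_shift)

# Only the correctly spelled wrapper is callable.
is_callable_ok = callable(two_views)
is_callable_broken = callable(broken)

# Each call draws two independent augmentations of the same sample.
xi, xj = two_views([0.0, 1.0, 2.0])
```

With torchvision, the corrected wrapper would be passed exactly as in the original post, e.g. transform=TwoViews(data_transforms), and each dataset item would then hold the pair of views together with its label.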