Custom MNIST Dataset error, traceback to vision.py

I’m writing my own MNIST dataset by subclassing datasets.MNIST, hoping to apply the augmenting transform only to classes 2, 3 and 4.

# Baseline preprocessing: pad 2 pixels on every side, convert to a tensor and
# normalize (0.1307 / 0.3081 are presumably the MNIST mean / std).
train_transform = transforms.Compose(
    [
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Augmenting pipeline: same padding / normalization as train_transform, plus a
# random rotation of up to 20 degrees before the tensor conversion.
aug_transform = transforms.Compose(
    [
        transforms.Pad(2),
        transforms.RandomRotation(20),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

class myMNIST(datasets.MNIST):
    """MNIST subclass intended to use ``aug_transform`` for labels 2, 3 and 4.

    As posted, constructing it raises the ValueError shown in the traceback
    (caused by the ``super()`` call in ``__init__``), and ``__getitem__``
    uses names it never assigns.
    """
    def __init__(self,root,train,download,transform):
        #print(root)
        #print(transform)
        # BUG: super(datasets.MNIST, self) begins the method lookup *after*
        # datasets.MNIST, so MNIST.__init__ is skipped and the vision.py
        # __init__(self, root, transforms, transform, target_transform) runs.
        # The positional arguments then bind as transforms=train,
        # transform=download, target_transform=transform. train=True makes
        # ``transforms`` non-None and download=True makes ``transform``
        # non-None, which triggers the ValueError.
        # Fix: super(myMNIST, self).__init__(root, train, transform,
        #                                     download=download)
        super(datasets.MNIST,self).__init__(root, train,download,
                                            transform)
        print(root)
        print(transform)
        # NOTE: ignores the ``transform`` argument (always train_transform),
        # and this instance attribute hides the transform() method below.
        self.transform = train_transform

    def __getitem__(self, index):
        # BUG: ``x`` and ``y`` are never assigned in this method and ``index``
        # is unused, so they would have to come from globals or raise
        # NameError.

        if (y == 2 or y==3 or y ==4): 
            # self.transform is the Compose object set in __init__ (not the
            # method below), and it is called here without an image.
            x = self.transform()
        
        return x, y
    def transform(self): 
        # Unreachable as ``instance.transform()``: __init__ shadows it with an
        # instance attribute. If called, it rebinds self.transform and
        # returns None.
        self.transform = aug_transform

        
# Constructing the dataset is what raises the ValueError in the traceback
# below; the failing statement is the super() call inside myMNIST.__init__.
train_dataset_special_aug = myMNIST(root='data', 
                                    train=True,
                                    download=True, 
                                 transform= train_transform)

The error is as follows:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-153-e7ddc431947b> in <module>
     31                                     train=True,
     32                                     download=True,
---> 33                                  transform= train_transform)
     34 
     35 #print(train_dataset_special_aug)

<ipython-input-153-e7ddc431947b> in __init__(self, root, train, download, transform)
     13         #print(transform)
     14         super(datasets.MNIST,self).__init__(root, train,download,
---> 15                                             transform)
     16         print(root)
     17         print(transform)

~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torchvision/datasets/vision.py in __init__(self, root, transforms, transform, target_transform)
     15         has_separate_transform = transform is not None or target_transform is not None
     16         if has_transforms and has_separate_transform:
---> 17             raise ValueError("Only transforms or transform/target_transform can "
     18                              "be passed as argument")
     19 

ValueError: Only transforms or transform/target_transform can be passed as argument

I looked at vision.py, and it seems to me that I should never reach this line, since transforms is None and hence has_transforms is False.

I’d really appreciate any lead/help, thank you.

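You are passing the arguments in the wrong order, and you are also calling the __init__ function of VisionDataset, while I assume you want to call the one of the parent class, i.e. MNIST. super(datasets.MNIST, self) starts the method lookup after MNIST, so MNIST.__init__ is skipped. Your positional train, download and transform arguments then end up in VisionDataset's transforms, transform and target_transform parameters, which triggers the check.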
This code should work:

class myMNIST(datasets.MNIST):
    """MNIST dataset that augments only some classes.

    Samples whose label is in ``aug_classes`` are transformed with the
    module-level ``aug_transform``. Every other sample uses the ``transform``
    given to the constructor.

    Args:
        root: Dataset root directory, forwarded to ``datasets.MNIST``.
        train: Forwarded to ``datasets.MNIST``.
        download: Forwarded to ``datasets.MNIST`` (by keyword).
        transform: Transform for samples whose label is not augmented.
        aug_classes: Labels that receive ``aug_transform``. Defaults to
            ``(2, 3, 4)``.
    """

    def __init__(self, root, train, download, transform, aug_classes=(2, 3, 4)):
        # super(myMNIST, self) resolves to MNIST.__init__. Passing ``download``
        # by keyword keeps it out of the ``target_transform`` slot.
        super(myMNIST, self).__init__(root, train, transform, download=download)
        # Use the caller's transform for non-augmented samples.
        self.transform = transform
        # Built once here instead of being re-tested per sample.
        self.aug_classes = frozenset(aug_classes)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The raw image from ``self.data`` is converted to PIL, then passed
        through ``aug_transform`` or ``self.transform`` depending on its label
        (taken from ``self.targets``).
        """
        x = self.data[index]
        y = self.targets[index]
        x = transforms.ToPILImage()(x)
        if y.item() in self.aug_classes:
            x = aug_transform(x)
        elif self.transform is not None:
            x = self.transform(x)
        return x, y

    def transform(self):
        # NOTE(review): never reachable as ``instance.transform()``, because
        # __init__ assigns a ``self.transform`` attribute that hides this
        # method on instances. Kept only so ``myMNIST.transform`` still exists.
        self.transform = aug_transform

    
# Build the training split with the custom dataset class defined above.
train_dataset_special_aug = myMNIST(
    root='data',
    train=True,
    download=True,
    transform=train_transform,
)

Thank you for the prompt reply!

I realized that error later too. In the end, I was able to transform the classes as wanted, but the code seems really redundant (sorry I’m a newbie!).

Is there a way to do the conditional transform in the transform() method instead of in the __getitem__() method?

Thank you!


# Non-augmenting pipeline for the classes that should not be augmented:
# pad, convert to a tensor, normalize.
ori_transform = transforms.Compose(
    [
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
# Augmenting pipeline: same padding / normalization as ori_transform, with a
# random affine jitter (rotation, translation, scaling) before ToTensor.
aug_transform = transforms.Compose(
    [
        transforms.Pad(2),
        transforms.RandomAffine(
            degrees=45,
            translate=(0.1, 0.1),
            scale=(0.8, 1.2),
        ),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)



class myMNIST(datasets.MNIST):
    """MNIST dataset that augments only selected classes.

    The parent is built without transforms. ``__getitem__`` converts each raw
    image to PIL, then applies ``transform`` if the label is in
    ``aug_classes`` and the module-level ``ori_transform`` otherwise.

    Args:
        root: Dataset root directory, forwarded to ``datasets.MNIST``.
        train: Forwarded to ``datasets.MNIST``.
        download: Forwarded to ``datasets.MNIST``.
        transform: Transform applied to samples of the augmented classes.
        aug_classes: Labels that receive ``transform``. Defaults to
            ``(2, 3, 5, 6, 7, 9)``.
    """

    def __init__(self,
                 root,
                 train,
                 download,
                 transform,
                 aug_classes=(2, 3, 5, 6, 7, 9)):
        # No transforms for the parent: they are chosen per label in
        # __getitem__.
        super().__init__(root, train, None, None, download)

        self.train = train
        self.transform = transform
        self.download = download
        # Built once here instead of rebuilding a list on every __getitem__.
        self.aug_classes = frozenset(aug_classes)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        ``label`` is returned exactly as stored in ``self.targets``.
        """
        x = self.data[index]
        y = self.targets[index]
        # Both branches need a PIL image, so convert once up front.
        x = transforms.ToPILImage()(x)
        if y.item() in self.aug_classes:
            x = self.transform(x)
        else:
            x = ori_transform(x)
        return x, y

    def transform(self):
        # NOTE(review): never reachable as ``instance.transform()``, because
        # __init__ assigns a ``self.transform`` attribute that hides this
        # method on instances. Kept only so ``myMNIST.transform`` still exists.
        self.transform = aug_transform
        
     

# Training split: augmented classes go through aug_transform (passed as
# ``transform``); all other classes go through ori_transform.
train_dataset_special_aug = myMNIST(
    root='data',
    train=True,
    transform=aug_transform,
    download=True,
)

I think applying the conditional transformation in __getitem__ is the clean way to do it, and I would not recommend trying to fit it into a transformation class somehow.

1 Like