TypeError: object() takes no parameters while iterating trough datasets

I got an error when making a class with Dataset as a parent.
the error message is :

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
 in 
     45                         num_workers=0)
     46 
---> 47         for batch_idx, (data, target, idx) in enumerate(loader):
     48                 print('Batch idx {}, dataset index {}'.format(
     49                         batch_idx, idx))

~\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    613         if self.num_workers == 0:  # same-process loading
    614             indices = next(self.sample_iter)  # may raise StopIteration
--> 615             batch = self.collate_fn([self.dataset[i] for i in indices])
    616             if self.pin_memory:
    617                 batch = pin_memory_batch(batch)

~\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py in (.0)
    613         if self.num_workers == 0:  # same-process loading
    614             indices = next(self.sample_iter)  # may raise StopIteration
--> 615             batch = self.collate_fn([self.dataset[i] for i in indices])
    616             if self.pin_memory:
    617                 batch = pin_memory_batch(batch)

d:\Projects\pytorch\forgetting-examples-mixup\dataloaders.py in __getitem__(self, index)
     24 
     25     def __getitem__(self, index):
---> 26         data, target = self.ds[index]
     27 
     28         # Your transformations here (or set it in CIFAR10)

~\Anaconda3\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py in __getitem__(self, index)
     75 
     76         if self.transform is not None:
---> 77             img = self.transform(img)
     78 
     79         if self.target_transform is not None:

TypeError: object() takes no parameters

my class :

from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

dsname = "mnist"
class IndexedDataset(Dataset):
    """
    this class can return the index of the example in the dataset
    for example-forgetting indexing purpose (saving forgetting statistics)
    """
    def __init__(self):
        self.ds = None
        if dsname == "mnist":
            self.ds = datasets.MNIST(root='data/mnist',
                                        download=True,
                                        train=True,
                                        transform=transforms.ToTensor)
        elif dsname == "cifar10":
            self.ds = datasets.CIFAR10(root='data/cifar10',
                                        download=True,
                                        train=True,
                                        transform=transforms.ToTensor)
        else :
            raise Exception('dsname must be "mnist" or "cifar10", dsname was: {}'.format(dsname))

    def __getitem__(self, index):
        data, target = self.ds[index]
        return data, target, index

    def __len__(self):
        return len(self.ds)

the error raised when I run :

dataset = dataloaders.IndexedDataset()

        loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=True,
                        num_workers=0)

        for batch_idx, (data, target, idx) in enumerate(loader):
                print('Batch idx {}, dataset index {}'.format(
                        batch_idx, idx))

did i miss something ?

Hi,

The transform should be an instance, not the class and so should be transforms.ToTensor(). Check the mnist example here.

3 Likes

works like a charm, thanks !