I got an error when making a class with Dataset as a parent.
the error message is :
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
in
45 num_workers=0)
46
---> 47 for batch_idx, (data, target, idx) in enumerate(loader):
48 print('Batch idx {}, dataset index {}'.format(
49 batch_idx, idx))
~\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
613 if self.num_workers == 0: # same-process loading
614 indices = next(self.sample_iter) # may raise StopIteration
--> 615 batch = self.collate_fn([self.dataset[i] for i in indices])
616 if self.pin_memory:
617 batch = pin_memory_batch(batch)
~\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py in (.0)
613 if self.num_workers == 0: # same-process loading
614 indices = next(self.sample_iter) # may raise StopIteration
--> 615 batch = self.collate_fn([self.dataset[i] for i in indices])
616 if self.pin_memory:
617 batch = pin_memory_batch(batch)
d:\Projects\pytorch\forgetting-examples-mixup\dataloaders.py in __getitem__(self, index)
24
25 def __getitem__(self, index):
---> 26 data, target = self.ds[index]
27
28 # Your transformations here (or set it in CIFAR10)
~\Anaconda3\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py in __getitem__(self, index)
75
76 if self.transform is not None:
---> 77 img = self.transform(img)
78
79 if self.target_transform is not None:
TypeError: object() takes no parameters
my class :
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
dsname = "mnist"
class IndexedDataset(Dataset):
"""
this class can return the index of the example in the dataset
for example-forgetting indexing purpose (saving forgetting statistics)
"""
def __init__(self):
self.ds = None
if dsname == "mnist":
self.ds = datasets.MNIST(root='data/mnist',
download=True,
train=True,
transform=transforms.ToTensor)
elif dsname == "cifar10":
self.ds = datasets.CIFAR10(root='data/cifar10',
download=True,
train=True,
transform=transforms.ToTensor)
else :
raise Exception('dsname must be "mnist" or "cifar10", dsname was: {}'.format(dsname))
def __getitem__(self, index):
data, target = self.ds[index]
return data, target, index
def __len__(self):
return len(self.ds)
the error raised when I run :
dataset = dataloaders.IndexedDataset()
loader = DataLoader(dataset,
batch_size=1,
shuffle=True,
num_workers=0)
for batch_idx, (data, target, idx) in enumerate(loader):
print('Batch idx {}, dataset index {}'.format(
batch_idx, idx))
did i miss something ?