I have written a custom dataset class to load an image from a path along with two transform functions as given below:
class TestDataset(torch.utils.data.Dataset):
def __init__(self, root, split, transform=None):
self.image_path = list()
for (dirpath, dirnames, filenames) in os.walk(root + split):
self.image_path += [os.path.join(dirpath, file) for file in filenames if file.endswith(('.jpg','.png'))]
self.transform = transform
def __getitem__(self,index):
filename = self.image_path[index]
with Image.open(filename) as f:
image = f.convert('RGB')
image = image.resize((640,480), Image.ANTIALIAS)
if self.transform != None:
image = self.transform(image)
return image
def __len__(self):
return len(self.image_path)
d_test = TestDataset(datadir, 'test', transform=ttransforms.Compose([ttransforms.ToTensor(),
ttransforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]))
loader_test = DataLoader(d_test, num_workers=4, batch_size=1, shuffle=False)
When I enumerate over the dataloader, I get the following error:
for iter, img in enumerate(loader_test):
img = img.cuda()
model = model.cuda()
#.........CODE .........#
Traceback (most recent call last):
File "test_script.py", line 53, in <module>
for iter, img in enumerate(loader_test):
File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-
packages/torch/utils/data/d$taloader.py", line 363, in __next__
data = self._next_data()
File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-
packages/torch/utils/data/d$taloader.py", line 989, in _next_data
return self._process_data(data)
File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-
packages/torch/utils/data/d$taloader.py", line 1014, in _process_data
data.reraise()
File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/_utils.py",
$ine 395, in reraise
raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-
packages/torch/utils/data/_$tils/worker.py", line 185, in _worker_loop
data = fetcher.fetch(index)
File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-
packages/torch/utils/data/_$tils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-
packages/torch/utils/data/_$tils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ashutosh.mishra/VPGNet-Pytorch/Exp7_unet/test_dataset.py", line 26, in __getitem__
image = self.transform(image)
File "/home/ashutosh.mishra/VPGNet-Pytorch/Exp7_unet/test_transform.py", line 67, in __call__
args = t(*args)
TypeError: __call__() takes 2 positional arguments but 4 were given
My transform compose looks like this:
class Compose(object):
"""Composes several transforms together."""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, *args):
for t in self.transforms:
args = t(*args)
return args
I am not able to figure out what is the source of the error since we load single images in the dataloader and enumerate on it. Apart from it, the transforms compose is the standard class which is used to compose different transformations.