Custom dataloader and transform errors

I have written a custom dataset class that loads images from a directory and applies a composition of two transforms, as shown below:

```python
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader

# ttransforms is the custom transform module (test_transform.py in the traceback below)


class TestDataset(torch.utils.data.Dataset):

    def __init__(self, root, split, transform=None):
        self.image_path = list()
        # os.path.join works whether or not root ends with a path separator
        for (dirpath, dirnames, filenames) in os.walk(os.path.join(root, split)):
            self.image_path += [os.path.join(dirpath, file) for file in filenames
                                if file.endswith(('.jpg', '.png'))]
        self.transform = transform

    def __getitem__(self, index):
        filename = self.image_path[index]
        with Image.open(filename) as f:
            image = f.convert('RGB')
        # Image.LANCZOS is the filter Image.ANTIALIAS aliased (ANTIALIAS was removed in Pillow 10)
        image = image.resize((640, 480), Image.LANCZOS)
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.image_path)


d_test = TestDataset(datadir, 'test', transform=ttransforms.Compose([
    ttransforms.ToTensor(),
    ttransforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]))

loader_test = DataLoader(d_test, num_workers=4, batch_size=1, shuffle=False)
```

When I iterate over the dataloader as follows, I get the error below:

```python
for iter, img in enumerate(loader_test):
    img = img.cuda()
    model = model.cuda()
    #.........CODE .........#
```

```
Traceback (most recent call last):
  File "test_script.py", line 53, in <module>
    for iter, img in enumerate(loader_test):
  File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 363, in __next__
    data = self._next_data()
  File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 989, in _next_data
    return self._process_data(data)
  File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1014, in _process_data
    data.reraise()
  File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ashutosh.mishra/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ashutosh.mishra/VPGNet-Pytorch/Exp7_unet/test_dataset.py", line 26, in __getitem__
    image = self.transform(image)
  File "/home/ashutosh.mishra/VPGNet-Pytorch/Exp7_unet/test_transform.py", line 67, in __call__
    args = t(*args)
TypeError: __call__() takes 2 positional arguments but 4 were given
```

My Compose class looks like this:

```python
class Compose(object):
    """Composes several transforms together."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        for t in self.transforms:
            args = t(*args)
        return args
```

I can't figure out the source of the error, since the dataloader loads single images and I simply enumerate over it. Apart from that, this Compose is the standard class used to chain different transformations together.

It's not completely identical to torchvision's Compose, since you are using `*args` instead of a single `img` argument, and `*args` unpacks the image tensor.
The error is raised in the second transformation: after ToTensor, `args` holds an image tensor of shape [channels, height, width], and `t(*args)` unpacks it along dim0 before calling Normalize. With 3 channels, Normalize's `__call__` receives three tensors in addition to `self`, i.e. 4 positional arguments instead of the expected 2, which is exactly the error you are seeing.
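A minimal pure-Python reproduction of this mechanism, assuming stand-in classes instead of the real transforms: `ToTensorStub`, `NormalizeStub` and `StarArgsCompose` are hypothetical names, and a nested list with 3 "channels" plays the role of the [3, H, W] tensor (unpacking either one with `*` splits it along its first dimension).

```python
class ToTensorStub:
    """Hypothetical stand-in for ToTensor: returns one 3-channel 'tensor'."""

    def __call__(self, img):
        # The outer list is "dim0" (channels); each channel is a 2x2 plane.
        return [[[0.0, 0.0], [0.0, 0.0]] for _ in range(3)]


class NormalizeStub:
    """Hypothetical stand-in for a single-input transform such as Normalize."""

    def __call__(self, img):
        return img


class StarArgsCompose:
    """Copy of the Compose from the question, renamed for this sketch."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        for t in self.transforms:
            # 1st iteration: args == (image,), so t receives one argument.
            # 2nd iteration: args is the 3-channel result, so *args splits
            # it into 3 positional arguments (plus self = 4).
            args = t(*args)
        return args


def reproduce():
    """Run the two-step pipeline and return the TypeError message, if any."""
    compose = StarArgsCompose([ToTensorStub(), NormalizeStub()])
    try:
        compose("fake PIL image")
    except TypeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(reproduce())
```

The printed message contains "takes 2 positional arguments but 4 were given", matching the last line of the traceback.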

What's your use case for reimplementing transforms.Compose this way?

Thanks for your reply. I already had a transform file containing several custom transforms, so I didn't use the built-in torchvision one. I also checked my installed torchvision version, and it gives the following error:

AttributeError: module 'torchvision.transforms' has no attribute 'Compose' 

However, changing `*args` to a single `image` argument in the Compose class made it work as expected.
Thanks again!
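A sketch of the corrected class, following the single-argument pattern of torchvision's Compose. The second class, `TupleSafeCompose`, is a hypothetical alternative for the case where the same transform file also holds transforms that take and return several values (for example an image and a label); it assumes such transforms return a tuple and spreads the result with `*` only in that case.

```python
class Compose:
    """Chains single-input transforms: each output is passed to the next one."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class TupleSafeCompose:
    """Hypothetical variant mixing multi-input and single-input transforms."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        for t in self.transforms:
            out = t(*args)
            # Only a tuple counts as "several values"; anything else
            # (e.g. a tensor) is wrapped so it stays ONE argument.
            args = out if isinstance(out, tuple) else (out,)
        # Unwrap a single value so image-only pipelines behave like Compose.
        return args[0] if len(args) == 1 else args
```

With either class, the `d_test` definition can stay as written, because ToTensor's output reaches Normalize as one argument. The trade-off of the tuple-based variant is that a transform which deliberately returns a tuple as a single value would be spread into several arguments.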

Which torchvision version are you using? I would have guessed Compose was added pretty early.