Single DataLoader over two datasets with different transforms

I have 2 datasets, each with a different sequence of transforms; for example, the first transform differs in whether the data needs to be converted to a PIL Image. The transforms are applied inside each dataset's __getitem__ method.

I am able to construct a DataLoader on each dataset separately and train properly.

I am trying to construct a single DataLoader with the 2 datasets merged and randomly interleaved.

I used a ConcatDataset to combine the 2 datasets and passed it to the DataLoader (with shuffling enabled), but I get the following error after what appears to be one successfully completed epoch. I'm not sure why this error occurs, since the DataLoader works when constructed from each dataset separately.
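A minimal sketch of the setup being described, assuming two toy datasets (the names DatasetA and DatasetB are hypothetical, and simple arithmetic stands in for the real transform pipelines such as a PIL conversion). The point it illustrates is that ConcatDataset plus a shuffled DataLoader can mix samples from both datasets in one batch, so both __getitem__ methods should return the same structure, types, and shapes:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class DatasetA(Dataset):
    """Hypothetical dataset applying its own transform pipeline in __getitem__."""

    def __init__(self, n=8):
        self.images = torch.randn(n, 3, 4, 4)
        self.labels = torch.randint(0, 2, (n,))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx] * 2.0  # stand-in for transform pipeline A
        return image, int(self.labels[idx])  # label as a plain Python int


class DatasetB(DatasetA):
    """Hypothetical second dataset with a different transform pipeline."""

    def __getitem__(self, idx):
        image = self.images[idx] + 1.0  # stand-in for transform pipeline B
        return image, int(self.labels[idx])  # same return types as DatasetA


combined = ConcatDataset([DatasetA(n=8), DatasetB(n=6)])
loader = DataLoader(combined, batch_size=4, shuffle=True)

for images, labels in loader:
    # A single batch may contain samples from both datasets.
    print(images.shape, labels)
```

Because both datasets return (tensor of the same shape, int), every shuffled batch collates the same way regardless of which dataset each sample came from.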

AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/.conda/envs/di_env/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home//.conda/envs/di_env/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home//.conda/envs/di_env/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home//.conda/envs/di_env/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home//.conda/envs/di_env/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 52, in default_collate
    numel = sum([x.numel() for x in batch])
  File "/home//.conda/envs/di_env/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 52, in <listcomp>
    numel = sum([x.numel() for x in batch])
AttributeError: 'int' object has no attribute 'numel'
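A minimal reproduction of this kind of failure, assuming (hypothetically) that one dataset returns its label as a tensor and the other returns it as a Python int. It calls default_collate directly in the main process, where the same mismatch surfaces from torch.stack as a TypeError rather than the AttributeError raised inside a worker:

```python
import torch
from torch.utils.data.dataloader import default_collate

# Hypothetical (image, label) samples from the two datasets.
sample_from_a = (torch.zeros(3, 4, 4), torch.tensor(1))  # label returned as a tensor
sample_from_b = (torch.zeros(3, 4, 4), 0)                # label returned as a plain int

# A batch drawn from a single dataset collates fine.
images, labels = default_collate([sample_from_a, sample_from_a])
print(labels)  # tensor([1, 1])

# A shuffled batch mixing both datasets fails: default_collate picks its
# strategy from the first sample, so it treats every label as a tensor.
try:
    default_collate([sample_from_a, sample_from_b])
except (TypeError, AttributeError) as err:
    print(type(err).__name__, err)
```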

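One of your datasets seems to return an int where a tensor is expected.
default_collate decides how to collate each field based on the type of that field in the first sample of the batch. If it is a tensor, it assumes the same field is a tensor in every sample; inside a worker process it calls .numel() on each of them, which fails as soon as a sample from the other dataset returns a plain Python int. That is also why each dataset works on its own: the error can only appear once a shuffled batch mixes samples from both.
You could add print statements into the Dataset.__getitem__ methods, check the type of all returned objects, and make sure both datasets return tensors (or, more generally, the same types) for every field.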
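Instead of ad-hoc print statements, a small helper can compare the returned types across datasets. This is a sketch under stated assumptions: the helper field_types and the three dataset classes are hypothetical, and the fix shown (converting the int label with torch.as_tensor) is one option; returning a plain int from both datasets would work equally well.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


def field_types(dataset, num_samples=5):
    """Collect the type names of each returned field for the first few samples."""
    signatures = set()
    for idx in range(min(num_samples, len(dataset))):
        sample = dataset[idx]
        if not isinstance(sample, (tuple, list)):
            sample = (sample,)
        signatures.add(tuple(type(field).__name__ for field in sample))
    return signatures


class TensorLabelDataset(Dataset):
    """Hypothetical dataset that returns its label as a tensor."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.zeros(3, 4, 4), torch.tensor(idx % 2)


class IntLabelDataset(Dataset):
    """Hypothetical dataset that returns its label as a Python int."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.zeros(3, 4, 4), idx % 2


class FixedIntLabelDataset(IntLabelDataset):
    """Same data, but the label is converted so the types match."""

    def __getitem__(self, idx):
        image, label = super().__getitem__(idx)
        return image, torch.as_tensor(label)


print(field_types(TensorLabelDataset()))    # {('Tensor', 'Tensor')}
print(field_types(IntLabelDataset()))       # {('Tensor', 'int')}
print(field_types(FixedIntLabelDataset()))  # {('Tensor', 'Tensor')}

# With matching types, the merged and shuffled dataset collates cleanly.
loader = DataLoader(
    ConcatDataset([TensorLabelDataset(), FixedIntLabelDataset()]),
    batch_size=4,
    shuffle=True,
)
for images, labels in loader:
    print(images.shape, labels)
```

Checking only the first few samples keeps the helper cheap, but it assumes each dataset returns consistent types across all indices; a dataset that branches on the index could still hide a mismatch.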