TypeError when concatenating datasets and shuffling

Hello all,

I have noticed some behaviour with ConcatDataset that I did not expect, and I managed to reproduce it with the minimal example below.
When shuffle=True is removed, the script runs without any problems. Furthermore, the printed shapes show that all elements in the concatenated dataset have the same size.

import torch
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset
import torchvision
from torchvision import transforms

nb_samples = 100
features = torch.randn(nb_samples, 1, 28, 28)  # same shape as an MNIST image
labels = torch.empty(nb_samples, dtype=torch.long).random_(10)  # random labels in [0, 10)

dataset = TensorDataset(features, labels)
DATASET_DIR = '../../data/datasets'

transform = transforms.Compose([transforms.ToTensor()])

mnist = torchvision.datasets.MNIST(root=DATASET_DIR, train=True, download=True, transform=transform)

concat_set = ConcatDataset([dataset, mnist])  # renamed to avoid shadowing the built-in set

loader = DataLoader(
    concat_set,
    batch_size=10, shuffle=True
)
for batch_idx, (x, y) in enumerate(loader):
    print(x.shape, y.shape)

If I include shuffle=True, I get a TypeError with the following stack trace:

Traceback (most recent call last):
  File "test.py", line 23, in <module>
    for batch_idx, (x, y) in enumerate(loader):
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
TypeError: expected Tensor as element 1 in argument 0, but got int

Is this expected behaviour, i.e. is it somehow not possible to use shuffling with ConcatDataset?
From some debugging, it looks like some batches end up as tuples in which some values are scalar tensors and others are plain ints, but the reason is not clear to me.

According to python -c "import torch; print(torch.__version__)", I am on PyTorch version 1.5.0.

Help with this would be greatly appreciated.

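It seems that this line of code in MNIST's __getitem__ is causing the error, as it converts the target to a Python int. TensorDataset, on the other hand, returns each label as a tensor, so once shuffling puts samples from both datasets into the same batch, default_collate can end up calling torch.stack on a list that contains both tensors and ints. Without shuffling, every batch of 10 happens to come from only one of the two datasets (the TensorDataset has 100 samples, a multiple of 10), which is why the error does not show up.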
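The type mismatch can be reproduced without downloading MNIST. The following is a minimal sketch under stated assumptions: `IntLabelDataset` is a hypothetical stand-in that, like the MNIST target described above, returns its label as a plain int. Using batch_size=4 without shuffling forces the first batch to span both sub-datasets, so the same TypeError appears deterministically.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, TensorDataset


class IntLabelDataset(Dataset):
    """Hypothetical stand-in for MNIST: returns (tensor, int) pairs."""

    def __init__(self, n):
        self.features = torch.randn(n, 1, 28, 28)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], 7  # plain Python int label


# TensorDataset returns each label as a 0-dim tensor.
tensor_ds = TensorDataset(torch.randn(3, 1, 28, 28), torch.zeros(3, dtype=torch.long))
concat = ConcatDataset([tensor_ds, IntLabelDataset(3)])

# A batch drawn from a single sub-dataset collates fine.
x, y = next(iter(DataLoader(tensor_ds, batch_size=3)))
print(x.shape, y.shape)  # torch.Size([3, 1, 28, 28]) torch.Size([3])

# Without shuffling, batch_size=4 makes the first batch hold three tensor labels
# followed by one int label, so default_collate calls torch.stack on mixed types.
mixed_error = None
try:
    next(iter(DataLoader(concat, batch_size=4)))
except TypeError as err:
    mixed_error = err
    print("TypeError:", err)
```

With shuffle=True in the original script, the same thing can happen whenever a shuffled batch mixes samples from the two datasets.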

As a workaround, you could pass a target_transform to MNIST that converts the target back to a tensor:

mnist = torchvision.datasets.MNIST(
    root=DATASET_DIR, train=True, download=True, transform=transform,
    target_transform=transforms.Lambda(lambda y: torch.tensor(y)))
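If one of the concatenated datasets does not accept a target_transform, another option (not suggested in this thread, just a sketch) is a custom collate_fn that converts every label with torch.as_tensor before stacking. `collate_with_tensor_labels` and `IntLabelDataset` are hypothetical names for illustration, and the sketch assumes every sample is a (feature_tensor, label) pair whose features share one shape.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, TensorDataset


class IntLabelDataset(Dataset):
    """Hypothetical dataset that returns (tensor, int) pairs."""

    def __init__(self, n):
        self.features = torch.randn(n, 1, 28, 28)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], 7  # plain Python int label


def collate_with_tensor_labels(batch):
    """Hypothetical collate_fn: stack features, converting every label to a tensor first."""
    features = torch.stack([x for x, _ in batch])
    labels = torch.stack([torch.as_tensor(y, dtype=torch.long) for _, y in batch])
    return features, labels


tensor_ds = TensorDataset(torch.randn(3, 1, 28, 28), torch.zeros(3, dtype=torch.long))
concat = ConcatDataset([tensor_ds, IntLabelDataset(3)])

# Shuffling is safe now: mixed tensor/int labels are normalized before stacking.
loader = DataLoader(concat, batch_size=4, shuffle=True, collate_fn=collate_with_tensor_labels)
batch_shapes = [(x.shape, y.shape) for x, y in loader]
for x_shape, y_shape in batch_shapes:
    print(x_shape, y_shape)
```

Compared with target_transform, this leaves each dataset untouched and does the conversion once per batch; the trade-off is that the collate function has to know the sample structure.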

Thanks for having a look, ptrblck!
And thank you for the workaround; it indeed solved my problem.
I hadn't considered putting a transform on the target.