Hello all,
I have noticed some behaviour that I did not expect when using ConcatDataset, and I managed to reproduce it with the minimal example below.
When shuffle=True is removed, the script runs without any problems. Furthermore, the prints show that the sizes of all the elements in the concatenated dataset are the same.
import torch
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset
import torchvision
from torchvision import transforms
nb_samples = 100
features = torch.randn(nb_samples, 1, 28, 28)
labels = torch.empty(nb_samples, dtype=torch.long).random_(10)
dataset = TensorDataset(features, labels)
DATASET_DIR = '../../data/datasets'
transform = transforms.Compose([transforms.ToTensor()])
mnist = torchvision.datasets.MNIST(root=DATASET_DIR, train=True, download=True, transform=transform)
# Named concat_dataset rather than set, to avoid shadowing the built-in
concat_dataset = ConcatDataset([dataset, mnist])
loader = DataLoader(
    concat_dataset,
    batch_size=10, shuffle=True
)
for batch_idx, (x, y) in enumerate(loader):
    print(x.shape, y.shape)
Now if I include shuffle=True, I get a TypeError with the following stack trace:
Traceback (most recent call last):
  File "test.py", line 23, in <module>
    for batch_idx, (x, y) in enumerate(loader):
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/Users/MyUser/opt/anaconda3/envs/conda-env-name/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
TypeError: expected Tensor as element 1 in argument 0, but got int
Is this expected behaviour, i.e. is it somehow not possible to use shuffling with the ConcatDataset?
From some debugging, it looks like some batches end up as tuples of values in which some entries are scalar tensors and others are plain int. The reason for this is not clear to me.
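To make that observation concrete, here is a minimal sketch that inspects the element types without downloading MNIST. IntLabelDataset is a hypothetical stand-in I made up for illustration. It assumes that MNIST returns its label as a plain Python int, which is what the TypeError suggests but which I have not confirmed.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, ConcatDataset


class IntLabelDataset(Dataset):
    """Hypothetical stand-in for MNIST: yields (image tensor, plain int label)."""

    def __init__(self, nb_samples):
        self.features = torch.randn(nb_samples, 1, 28, 28)
        self.labels = [i % 10 for i in range(nb_samples)]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


nb_samples = 100
tensor_ds = TensorDataset(
    torch.randn(nb_samples, 1, 28, 28),
    torch.randint(0, 10, (nb_samples,)),
)
int_ds = IntLabelDataset(nb_samples)
combined = ConcatDataset([tensor_ds, int_ds])

# Indices 0..99 come from tensor_ds, indices 100..199 from int_ds.
_, label_a = combined[0]
_, label_b = combined[nb_samples]
print(type(label_a), type(label_b))

# Stacking a tensor label together with an int label raises a TypeError,
# which is the same kind of failure as in the traceback above.
try:
    torch.stack([label_a, label_b], 0)
    stack_failed = False
except TypeError as err:
    stack_failed = True
    print("TypeError:", err)
```

If this assumption holds, it might also explain why the unshuffled run works: with 100 samples and batch_size=10, no unshuffled batch mixes samples from the two datasets, whereas a shuffled batch can.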
According to python -c "import torch; print(torch.__version__)", I am on PyTorch version 1.5.0.
Help with this would be greatly appreciated.