I’m using ConcatDataset to combine multiple (custom) torch.Dataset instances, and then I want to sample batches from the concatenation. Because I use a custom data structure, I want the DataLoader to use the __getitems__ call instead of calling __getitem__ once per sample and collating the results afterwards. The following code shows my attempt, based on the solutions from several other posts in this forum.
import torch
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, Dataset, SequentialSampler

DS_SIZE = 2
BATCH_SIZE = 5


class MyDataset(Dataset):
    def __init__(self, start: int) -> None:
        assert start >= 0
        self.data = torch.arange(start=start, end=start + DS_SIZE)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        print(f"single - idx: {idx}")
        return self.data[idx]

    def __getitems__(self, indices: list[int]) -> torch.Tensor:
        print(f"batch - indices: {indices}")
        return self.data[indices]


if __name__ == "__main__":
    dataset = ConcatDataset([MyDataset(start=DS_SIZE * i) for i in range(10)])

    dl1 = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
    for batch in dl1:
        print(f"1: {batch}")

    dl2 = DataLoader(
        dataset=dataset,
        sampler=BatchSampler(SequentialSampler(dataset), batch_size=BATCH_SIZE, drop_last=False),
        batch_size=None,
    )
    for batch in dl2:
        print(f"2: {batch}")
As expected, the first DataLoader uses __getitem__ plus the (custom) collate route. The second DataLoader, however, fails because ConcatDataset tries to compare the batched list of indices against a number without checking the type of the index first. See the code here. One step earlier, in the fetch call, a possibly_batched_index still seems to be expected, but ConcatDataset does not handle it. Is this a bug, or am I doing something wrong?
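The failing type check can be reproduced without any DataLoader machinery at all; a minimal sketch (using TensorDataset purely for illustration):

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Three tiny datasets of two samples each, concatenated into 6 samples.
ds = ConcatDataset([TensorDataset(torch.arange(2)) for _ in range(3)])

print(ds[5])  # an int index works: it is translated to a sub-dataset via bisect

try:
    ds[[0, 1, 2]]  # a list index reaches `if idx < 0` and raises
except TypeError as exc:
    print(exc)  # '<' not supported between instances of 'list' and 'int'
```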
Python: 3.10
PyTorch: 2.2.0+cu121
The error message:
  File ...
    for batch in dl2:
  File ".../venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File ".../venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File ".../venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 53, in fetch
    data = self.dataset[possibly_batched_index]
  File ".../venv/lib/python3.10/site-packages/torch/utils/data/dataset.py", line 326, in __getitem__
    if idx < 0:
TypeError: '<' not supported between instances of 'list' and 'int'
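For completeness, here is a workaround sketch I am experimenting with: a ConcatDataset subclass that adds its own __getitems__, translating each global index with the same bisect logic that ConcatDataset.__getitem__ uses. The names BatchedConcatDataset and RangeDataset are my own (RangeDataset is a stand-in for the custom dataset above), and torch.stack is just one possible way to combine the per-dataset samples; with a regular batch_size the fetcher then prefers __getitems__, so no BatchSampler-as-sampler trick is needed:

```python
import bisect
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class RangeDataset(Dataset):
    # Stand-in for the custom dataset in the question.
    def __init__(self, start: int, size: int) -> None:
        self.data = torch.arange(start, start + size)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]

    def __getitems__(self, indices: list[int]) -> torch.Tensor:
        return self.data[indices]


class BatchedConcatDataset(ConcatDataset):
    """Hypothetical ConcatDataset variant with batched index lookup."""

    def __getitems__(self, indices: list[int]) -> torch.Tensor:
        samples = []
        for idx in indices:
            # Same translation ConcatDataset.__getitem__ does for one index.
            dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
            offset = 0 if dataset_idx == 0 else self.cumulative_sizes[dataset_idx - 1]
            samples.append(self.datasets[dataset_idx][idx - offset])
        return torch.stack(samples)


if __name__ == "__main__":
    ds = BatchedConcatDataset([RangeDataset(start=2 * i, size=2) for i in range(10)])
    # batch_size=5 enables auto-collation, so the fetcher calls __getitems__;
    # the identity collate_fn keeps the already-batched tensor as-is.
    dl = DataLoader(ds, batch_size=5, collate_fn=lambda x: x)
    for batch in dl:
        print(batch)
```

This loses the vectorized per-sub-dataset indexing (each sample is still fetched individually from its sub-dataset), but it keeps the single batched call at the ConcatDataset level.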