Type error in ConcatDataset when using a BatchSampler

I’m using ConcatDataset to combine several custom torch.utils.data.Dataset instances, and then I want to sample batches from the concatenation. Because my data uses a custom structure, I want each batch fetched through a single __getitems__ call, rather than one __getitem__ call per sample followed by a collate. The following code shows my attempt, based on the solutions from several other posts in this forum.

import torch
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, Dataset, SequentialSampler

DS_SIZE = 2
BATCH_SIZE = 5


class MyDataset(Dataset):
    def __init__(self, start: int) -> None:
        assert start >= 0
        self.data = torch.arange(start=start, end=start + DS_SIZE)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        print(f"single - idx: {idx}")
        return self.data[idx]

    def __getitems__(self, indices: list[int]) -> torch.Tensor:
        print(f"batch - indices: {indices}")
        return self.data[indices]


if __name__ == "__main__":
    dataset = ConcatDataset([MyDataset(start=DS_SIZE * i) for i in range(10)])

    dl1 = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)

    for batch in dl1:
        print(f"1: {batch}")

    dl2 = DataLoader(
        dataset=dataset,
        sampler=BatchSampler(SequentialSampler(dataset), batch_size=BATCH_SIZE, drop_last=False),
        batch_size=None,
    )

    for batch in dl2:
        print(f"2: {batch}")

As expected, the first DataLoader goes through __getitem__ and a (custom) collate route. But the second DataLoader fails because ConcatDataset compares the batched list of indices against an integer without first checking the type of the index; see the code here. One step earlier, in the fetch call, a possibly_batched_index still seems to be expected, but ConcatDataset does not handle it. Is this a bug, or am I doing something wrong?
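One workaround I can sketch (an assumption on my side, not an official API): pass the BatchSampler via the batch_sampler= argument instead of sampler= with batch_size=None, so the fetcher takes the auto-collation branch that checks for __getitems__, and subclass ConcatDataset to provide a __getitems__ that maps each global index to its sub-dataset. The class name BatchableConcatDataset is made up for illustration; it relies on ConcatDataset's documented cumulative_sizes attribute:

```python
import bisect

import torch
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, Dataset, SequentialSampler

DS_SIZE = 2
BATCH_SIZE = 5


class MyDataset(Dataset):
    def __init__(self, start: int) -> None:
        self.data = torch.arange(start=start, end=start + DS_SIZE)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]

    def __getitems__(self, indices: list[int]) -> list[torch.Tensor]:
        # Return a list of samples so the DataLoader's collate_fn can stack them.
        return list(self.data[indices])


class BatchableConcatDataset(ConcatDataset):
    """Hypothetical ConcatDataset subclass that adds a batched __getitems__."""

    def __getitems__(self, indices: list[int]) -> list:
        out = [None] * len(indices)
        # Group the global indices by sub-dataset, remembering their positions.
        per_dataset: dict[int, list[tuple[int, int]]] = {}
        for pos, idx in enumerate(indices):
            ds_idx = bisect.bisect_right(self.cumulative_sizes, idx)
            local = idx if ds_idx == 0 else idx - self.cumulative_sizes[ds_idx - 1]
            per_dataset.setdefault(ds_idx, []).append((pos, local))
        # Fetch per sub-dataset, using its batched __getitems__ when available.
        for ds_idx, pairs in per_dataset.items():
            ds = self.datasets[ds_idx]
            locals_ = [local for _, local in pairs]
            if hasattr(ds, "__getitems__"):
                batch = ds.__getitems__(locals_)
            else:
                batch = [ds[i] for i in locals_]
            for (pos, _), item in zip(pairs, batch):
                out[pos] = item
        return out


dataset = BatchableConcatDataset([MyDataset(start=DS_SIZE * i) for i in range(10)])

# batch_sampler= (rather than sampler= with batch_size=None) routes through
# the fetcher branch that looks for __getitems__ on the dataset.
dl = DataLoader(
    dataset=dataset,
    batch_sampler=BatchSampler(SequentialSampler(dataset), batch_size=BATCH_SIZE, drop_last=False),
)

for batch in dl:
    print(batch)
```

With this sketch, the default collate stacks the per-sample tensors that __getitems__ returns, so each iteration yields a batch of BATCH_SIZE elements spanning sub-dataset boundaries.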

Python: 3.10
PyTorch: 2.2.0+cu121

The error message:

  File ... 
    for batch in dl2:
  File ".../venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File ".../venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File ".../venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 53, in fetch
    data = self.dataset[possibly_batched_index]
  File ".../venv/lib/python3.10/site-packages/torch/utils/data/dataset.py", line 326, in __getitem__
    if idx < 0:
TypeError: '<' not supported between instances of 'list' and 'int'

The same for me! I’m unable to use ConcatDataset with a BatchSampler; the same error occurs.