Implementing an "infinite loop" Dataset & DataLoader combo

I’d like to implement an infinite loop Dataset & DataLoader. Here’s what I tried:

import numpy as np
from torch.utils.data import DataLoader, Dataset

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
#         return 1<<30 # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"
        ...

As you can see, the main challenge here is the __len__() method. If I return a large number there, like 1<<30, memory usage jumps to 10+ GB on the first iteration of the training loop, and after a while the workers are killed, presumably due to OOM.

If I return a small number there, like 1 or BATCH_SIZE, the sampled “data” in the training loop is periodically duplicated. That is not what I want: I’d like new data to be generated and trained on at every iteration.

My guess is that something somewhere in the stack is caching a lot of data, but from a casual look at the Python side I can’t pinpoint where.

Can someone advise on the best way to implement this? (I want to use DataLoader’s parallel loading while guaranteeing that every batch loaded is entirely new.)


A solution that worked for me was a generator function built on itertools.repeat.

from itertools import repeat

def repeater(data_loader):
    for loader in repeat(data_loader):
        for data in loader:
            yield data

Then

data_loader = DataLoader(dataset, ...)
data_loader = repeater(data_loader)

for data in data_loader:
    # train your model
    ...
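A quick way to see what repeater does without a real dataset: the sketch below assumes only that the loader can be iterated more than once, so a plain list of two "batches" stands in for the DataLoader (fake_loader is a hypothetical name). Taking five batches shows the wrap-around past the end of one pass.

```python
from itertools import islice, repeat


def repeater(data_loader):
    """Yield batches from data_loader forever, starting a new pass each time one ends."""
    for loader in repeat(data_loader):
        for data in loader:
            yield data


# Hypothetical stand-in for a DataLoader: two "batches" per pass.
fake_loader = [[0, 1], [2, 3]]

# Take 5 batches, more than one pass' worth; the repeater starts over after [2, 3].
batches = list(islice(repeater(fake_loader), 5))
print(batches)  # [[0, 1], [2, 3], [0, 1], [2, 3], [0, 1]]
```

With a real DataLoader, every wrap-around calls iter() on the loader again, so a shuffling sampler reshuffles and, unless persistent_workers=True, the worker processes are started again for each pass.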

I ran into this as well (and many others likely will), so here is my solution.
If you look at torch.utils.data.dataloader._BaseDataLoaderIter you will find:

def _next_index(self):
    return next(self._sampler_iter)  # may raise StopIteration

so iteration stops only when your sampler's iterator raises StopIteration. In the DataLoader class itself you will also find:

    def __len__(self):
        if self._dataset_kind == _DatasetKind.Iterable:
            # (...) cut the comment here
            length = self._IterableDataset_len_called = len(self.dataset)
            return length
        else:
            return len(self._index_sampler)

This shows that, unless you have an IterableDataset, the length is defined by _index_sampler. I think the most elegant solution here is an infinite sampler with no length, since the length is not required to iterate over the DataLoader (see the comments in the PyTorch source).
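A small check of that claim, as a sketch: CountingSampler below is a hypothetical sampler that defines __iter__ but no __len__. Iterating over the DataLoader works and never runs out, while len(loader) raises TypeError because DataLoader.__len__ defers to the sampler. Everything runs on CPU in the main process (num_workers=0, the default).

```python
import itertools

from torch.utils.data import DataLoader, Sampler


class CountingSampler(Sampler):
    """Hypothetical sampler: yields 0, 1, 2, ... wrapped to the dataset size, forever.

    It defines __iter__ but deliberately no __len__.
    """

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        i = 0
        while True:
            yield i % self.n
            i += 1


dataset = list(range(10))  # a list works as a map-style dataset (__getitem__ + __len__)
loader = DataLoader(dataset, sampler=CountingSampler(len(dataset)), batch_size=4)

# Iterating never asks for a length, so we can keep drawing batches past one "epoch".
first_three = [batch.tolist() for batch in itertools.islice(loader, 3)]
print(first_three)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1]]

# Asking for the length fails, because DataLoader.__len__ defers to the sampler.
try:
    len(loader)
    len_failed = False
except TypeError:
    len_failed = True
print("len(loader) raised TypeError:", len_failed)  # True
```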

Something like this would be a good sampler:

import itertools

import torch

class InfiniteSampler:
    def __init__(self, dataset_size):
        self.dataset_size = dataset_size

    def __iter__(self):
        yield from itertools.islice(self._infinite(), 0, None, 1) # Infinite iterator

    def _infinite(self):
        g = torch.Generator()
        while True:
            # One fresh random permutation of all indices per pass; .tolist() yields plain ints
            yield from torch.randperm(self.dataset_size, generator=g).tolist()

A similar approach is currently taken in Detectron2.
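A usage sketch for this kind of sampler, with a tiny list standing in for the dataset. This copy subclasses torch.utils.data.Sampler, drops the pass-through islice, and converts indices with .tolist(); these are small deviations from the snippet, not requirements. Because each call to torch.randperm produces a full permutation, every consecutive run of dataset_size indices covers the whole dataset exactly once.

```python
import itertools

import torch
from torch.utils.data import DataLoader, Sampler


class InfiniteSampler(Sampler):
    """Yields dataset indices forever, one fresh random permutation after another."""

    def __init__(self, dataset_size):
        self.dataset_size = dataset_size

    def __iter__(self):
        g = torch.Generator()  # unseeded: starts from the same default seed on every run
        while True:
            # .tolist() yields plain Python ints rather than 0-d tensors
            yield from torch.randperm(self.dataset_size, generator=g).tolist()


dataset = list(range(6))
loader = DataLoader(dataset, sampler=InfiniteSampler(len(dataset)), batch_size=3)

# Draw four batches of 3, i.e. two full passes over the 6-element dataset.
batches = [b.tolist() for b in itertools.islice(loader, 4)]
first_pass = sorted(batches[0] + batches[1])   # indices from the first permutation
second_pass = sorted(batches[2] + batches[3])  # indices from the second permutation
print(first_pass, second_pass)  # [0, 1, 2, 3, 4, 5] [0, 1, 2, 3, 4, 5]
```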


Thanks! This is indeed the approach in detectron2. Just a kind reminder: there should probably be a seeding mechanism in __init__ and _infinite (something like the one in detectron2); otherwise the sampler will generate the same pseudo-random sequence on every run, because a freshly created torch.Generator() always starts from the same default seed.
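One way to add such a seed, as a sketch rather than detectron2's actual code: SeededInfiniteSampler and its seed argument are hypothetical names. An explicit seed makes the index sequence reproducible; if none is given, one is drawn from PyTorch's global RNG so that torch.manual_seed still controls it.

```python
import itertools

import torch
from torch.utils.data import Sampler


class SeededInfiniteSampler(Sampler):
    """Hypothetical infinite random sampler whose index sequence is fixed by a seed."""

    def __init__(self, dataset_size, seed=None):
        self.dataset_size = dataset_size
        if seed is None:
            # Draw a seed from PyTorch's global RNG, so torch.manual_seed(...) still controls it
            seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        self.seed = seed

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)  # each new iterator replays the same sequence from the start
        while True:
            yield from torch.randperm(self.dataset_size, generator=g).tolist()


# Same seed, same (infinite) index sequence; compare the first 12 indices.
a = list(itertools.islice(SeededInfiniteSampler(5, seed=0), 12))
b = list(itertools.islice(SeededInfiniteSampler(5, seed=0), 12))
print(a == b)  # True
```

Because the generator is reseeded inside __iter__, this makes whole runs reproducible but does not by itself let a stopped iterator resume mid-sequence; that still needs a single long-lived iterator.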

islice is not needed. We can just return the generator (which is also an iterator); I think this version is easier to understand:

import torch
from torch.utils import data

class InfiniteRandomSampler(data.Sampler):
    """ Return random indices from [0, n) infinitely.
    
    Arguments:
        dset_size (int): Size of the dataset to sample.
    """
    def __init__(self, dset_size):
        self.dset_size = dset_size

    def __iter__(self):
        # Create a random number generator (optional, makes the sampling independent of the base RNG)
        rng = torch.Generator()
        seed = torch.empty((), dtype=torch.int64).random_().item()
        rng.manual_seed(seed)
        
        return _infinite_generator(self.dset_size, rng)
    
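    # No __len__: len() must return an int, so returning float('inf') would make
    # len(sampler) raise TypeError. Iterating over a DataLoader does not need it.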
        
def _infinite_generator(n, rng):
    """Infinitely yields indices in [0, n), one random permutation at a time."""
    while True:
        yield from torch.randperm(n, generator=rng).tolist()
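To check the reproducibility that the seed line buys, here is a self-contained copy of the sampler above (without any __len__ method) used with a DataLoader; the helper first_batches is hypothetical. Calling torch.manual_seed before creating the iterator makes two runs produce the same batches, since the sampler's private generator is seeded from the global RNG.

```python
import itertools

import torch
from torch.utils.data import DataLoader, Sampler


def _infinite_generator(n, rng):
    """Infinitely yields indices in [0, n), one random permutation at a time."""
    while True:
        yield from torch.randperm(n, generator=rng).tolist()


class InfiniteRandomSampler(Sampler):
    """Same logic as the sampler above; an infinite sampler has no integer length."""

    def __init__(self, dset_size):
        self.dset_size = dset_size

    def __iter__(self):
        rng = torch.Generator()
        # Seed drawn from the global RNG, so torch.manual_seed() makes runs repeatable
        seed = torch.empty((), dtype=torch.int64).random_().item()
        rng.manual_seed(seed)
        return _infinite_generator(self.dset_size, rng)


def first_batches(k, global_seed=0):
    """Hypothetical helper: seed the global RNG, build a loader, return its first k batches."""
    torch.manual_seed(global_seed)
    loader = DataLoader(list(range(8)), sampler=InfiniteRandomSampler(8), batch_size=4)
    return [batch.tolist() for batch in itertools.islice(loader, k)]


run1 = first_batches(4)
run2 = first_batches(4)
print(run1 == run2)  # True: same global seed, same batches
```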

This creates an infinite DataLoader, but it won’t be “stateful”: if you stop sampling from it and later start again, sampling starts anew and forgets which examples were already returned in previous iterations. If you don’t want that, you’ll also have to use what @SunQpark proposed (which can also drop the dependency on itertools):

def make_infinite(dloader):
    while True:
        yield from dloader
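The statefulness difference can be seen with a plain list standing in for the DataLoader (fake_loader is a hypothetical name): a single long-lived make_infinite generator resumes where it stopped, while a fresh iter() starts over.

```python
from itertools import islice


def make_infinite(dloader):
    """One long-lived generator that restarts dloader each time it is exhausted."""
    while True:
        yield from dloader


fake_loader = ["b0", "b1", "b2"]  # hypothetical stand-in: three batches per pass

stream = make_infinite(fake_loader)
first = list(islice(stream, 2))    # ['b0', 'b1']
resumed = list(islice(stream, 3))  # ['b2', 'b0', 'b1']: continues, then wraps around

restarted = list(islice(iter(fake_loader), 2))  # a fresh iterator starts over: ['b0', 'b1']
print(first, resumed, restarted)
```

One caveat of this design: if the wrapped loader yields nothing (for example, an empty dataset), the while True loop spins forever without ever yielding.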