IterableDataset - why is len required?

Hello,

I am having trouble with an IterableDataset. PyTorch wants me to tell it the len of my data source, but since it's an iterable data source, why should that matter? If I can generate data at least as fast as I consume it, the data source could be functionally infinite.
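(For illustration, here is a toy sketch of what I mean by "functionally infinite"; InfiniteDataSet and the random payload are made up, not my real data:)

import torch
from torch.utils.data import IterableDataset

class InfiniteDataSet(IterableDataset):
    # A stream that never ends, so there is no sensible value for len().
    def __iter__(self):
        while True:
            yield torch.randn(3)  # stand-in sample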

My situation is that I have data in many files. I have enough memory to hold any individual file, but not all of them at once. My current approach is below. In short: the iterable hands out data stored in memory. When the iterable runs out, it refreshes by loading data from the next file into memory.

If I set len to 0, it doesn't load anything. So it checks the len before loading? That seems fundamentally opposed to what iterables are for.

If I leave len as 1, things seem to work, but I get several annoying warnings:

UserWarning: Length of IterableDataset <somestuff.MyDataSet object at 0x000002348AD5AF08> was reported to be 1 (when accessing len(dataloader)), but 2 samples have been fetched. warnings.warn(warn_msg)

If I set len to some very large number, things also seem to work and the warnings go away. I feel gross about it, though.
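(For reference, the "very large number" hack looks something like this; sys.maxsize is an arbitrary pick, and the toy __iter__ is only there to make the snippet self-contained:)

import sys
import torch
from torch.utils.data import IterableDataset

class HugeLenDataSet(IterableDataset):
    def __len__(self):
        # Report an absurdly large length so the "more samples have been
        # fetched than reported" warning never fires.
        return sys.maxsize

    def __iter__(self):
        # Stand-in stream; my real dataset iterates over files as shown below.
        return iter(torch.randn(4, 3))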

Why do I have to provide the len at all?

I found https://github.com/pytorch/pytorch/pull/23587, but that answer is unsatisfactory to me.

My code:

import os

import torch
from torch.utils.data import IterableDataset, DataLoader


class MyDataSet(IterableDataset):
    def __init__(self, sample_dir):
        self.sample_dir = sample_dir
        self.len = 1  # placeholder length; this is the value the warning complains about

    def __len__(self):
        return self.len

    def get_sample(self):
        # Stream samples one file at a time, so only one file's worth of
        # data is ever held in memory.
        for f in [os.path.join(self.sample_dir, filename) for filename in os.listdir(self.sample_dir)]:
            data = torch.load(f)
            for item in data:
                yield item

    def __iter__(self):
        return self.get_sample()

...

    dataset = MyDataSet('MyDir/samples')
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=1,
        num_workers=0
    )
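(For completeness, I then consume the loader with an ordinary loop; the body here is just a placeholder for my real training step:)

    for batch in data_loader:
        pass  # placeholder for the actual training step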