I am having trouble with an IterableDataset. PyTorch wants me to tell it the len of my data source. But since it's an iterable data source, why would that matter? If I can generate data at least as fast as I consume it, the data source could be functionally infinite.
My situation is that I have data in many files. I have enough memory to hold any individual file, but not all of them at once. My current approach is below. In short: the iterable hands out data stored in memory. When the iterable runs out, it refreshes by loading data from the next file into memory.
If I set len to 0, it doesn't load anything. So it checks the len before loading? That seems fundamentally opposed to what iterables are.
If I leave len as 1, things seem to work, but I get several annoying warnings:
UserWarning: Length of IterableDataset <somestuff.MyDataSet object at 0x000002348AD5AF08> was reported to be 1 (when accessing len(dataloader)), but 2 samples have been fetched. warnings.warn(warn_msg)
If I set len to some very large number, things also seem to work and the warnings go away, but that feels like a hack.
Why do I have to provide the len at all?
I found https://github.com/pytorch/pytorch/pull/23587, but that answer is unsatisfactory to me.
import os

import torch
from torch.utils.data import DataLoader, IterableDataset


class MyDataSet(IterableDataset):
    def __init__(self, sample_dir):
        self.sample_dir = sample_dir
        self.len = 1

    def __len__(self):
        return self.len

    def get_sample(self):
        # Stream samples file by file, so only one file's worth of
        # data is held in memory at a time.
        for f in [self.sample_dir + '/' + filename
                  for filename in os.listdir(self.sample_dir)]:
            data = torch.load(f)
            for item in data:
                yield item

    def __iter__(self):
        return self.get_sample()

...

dataset = MyDataSet('MyDir/samples')
data_loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    num_workers=0,
)
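For what it's worth, the streaming pattern itself needs no length at all when written as a plain generator. Here's a minimal, torch-free sketch of the same file-by-file refresh (using pickle as a stand-in for torch.load; stream_samples is a hypothetical name, not part of my real code):

```python
import os
import pickle


def stream_samples(sample_dir):
    """Yield items file by file; only one file is ever in memory."""
    for filename in sorted(os.listdir(sample_dir)):
        path = os.path.join(sample_dir, filename)
        with open(path, 'rb') as fh:
            data = pickle.load(fh)  # stand-in for torch.load
        for item in data:
            yield item
```

Iterating this exhausts one file, then transparently moves on to the next, with no way (and no need) to know the total count up front.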