Hello,
I am having trouble with an IterableDataset. PyTorch wants me to tell it the len of my data source. But since it's an iterable data source, why should that matter? If I can generate data at least as fast as I consume it, the data source could be functionally infinite.
My situation is that I have data in many files. I have enough memory to hold any individual file, but not all of them at once. My current approach is below. In short: the iterable hands out data stored in memory. When the iterable runs out, it refreshes by loading data from the next file into memory.
If I set len to 0, it doesn't load anything. So it checks the len before loading? That seems fundamentally opposed to what iterables are for.
If I leave len as 1, things seem to work, but I get several annoying warnings:
UserWarning: Length of IterableDataset <somestuff.MyDataSet object at 0x000002348AD5AF08> was reported to be 1 (when accessing len(dataloader)), but 2 samples have been fetched. warnings.warn(warn_msg)
If I set len to some very large number, things also seem to work and the warnings go away. I feel gross about it, though.
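For reference, the "very large number" workaround is just returning a huge sentinel from __len__ (sys.maxsize is a hypothetical choice; the class is trimmed to the relevant method here):

```python
import sys


class MyDataSet:  # IterableDataset subclass in the real code
    def __len__(self):
        # Sentinel "length" large enough that the DataLoader never sees
        # more samples fetched than were reported.
        return sys.maxsize
```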
Why do I have to provide the len at all?
I found https://github.com/pytorch/pytorch/pull/23587, but that answer is unsatisfactory to me.
My code:
import os

import torch
from torch.utils.data import DataLoader, IterableDataset


class MyDataSet(IterableDataset):
    def __init__(self, sample_dir):
        self.sample_dir = sample_dir
        self.len = 1

    def __len__(self):
        return self.len

    def get_sample(self):
        # Load one file at a time; hand out its items, then move on.
        for f in [os.path.join(self.sample_dir, filename)
                  for filename in os.listdir(self.sample_dir)]:
            data = torch.load(f)
            for item in data:
                yield item

    def __iter__(self):
        return self.get_sample()

...

dataset = MyDataSet('MyDir/samples')
data_loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    num_workers=0,
)