My training data is split into multiple files, so I have to use ConcatDataset and then split the result into train/valid/test sets using random_split.
I don’t think this is a very uncommon situation, but I cannot get this to work: one wants an iterable dataset and the other wants an indexable dataset, so I end up with something that works in neither situation.
I should add that I am using a custom IterableDataset that also implements indexing.
Unfortunately I cannot fix this by joining the raw files and loading them as one, since the data is interpreted as sequences and joining the files would create invalid sequences right where two files meet.
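One way this setup might work without joining the raw files is to wrap each file in its own map-style Dataset and combine them with ConcatDataset, which supports len() and indexing and therefore random_split. This is only a minimal sketch under the assumption that each file can be parsed into an in-memory list of sequences; SequenceFileDataset and the stand-in data are hypothetical names, not the original poster's code:

```python
import torch
from torch.utils.data import Dataset, ConcatDataset, random_split

class SequenceFileDataset(Dataset):
    """Map-style dataset over the sequences parsed from a single file."""
    def __init__(self, sequences):
        self.sequences = sequences  # list of tensors, one per sequence

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]

# one dataset per file keeps sequences from different files separate,
# so no invalid sequences are created at file boundaries
files = [[torch.arange(5), torch.arange(3)],  # stand-in for parsed file 1
         [torch.arange(4)]]                   # stand-in for parsed file 2
full = ConcatDataset([SequenceFileDataset(seqs) for seqs in files])

# random_split works because ConcatDataset is indexable and has a length
train, valid, test = random_split(
    full, [1, 1, 1], generator=torch.Generator().manual_seed(0))
print(len(full))  # 3
```

The key point is that the per-file boundaries live inside each SequenceFileDataset, while ConcatDataset only stitches the indices together.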
It’s only been 9 months, but in any case: could you describe how IterableDatasets should be concatenated? They don’t have a length, since they are used e.g. for streaming data, so I would assume that you could just switch to the next stream once the current one is exhausted?
@ptrblck I want to use 2 IterableDatasets in each epoch. As you said, once the iterator reaches the end of the first IterableDataset, I want to use the second IterableDataset. How do I achieve this?
You can just continue with the second IterableDataset as seen here:
import math
import torch
from torch.utils.data import DataLoader

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process: split the workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))

ds1 = MyIterableDataset(start=3, end=7)
loader1 = DataLoader(ds1, num_workers=1)
ds2 = MyIterableDataset(start=7, end=10)
loader2 = DataLoader(ds2, num_workers=1)

for epoch in range(2):
    for data in loader1:
        print(data)  # train with the first loader
    for data in loader2:
        print(data)  # train with the second loader
ChainDataset would also work, and you can pick whichever approach fits your needs.
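For reference, the ChainDataset route might look like this. RangeStream is a hypothetical minimal stand-in for a streaming dataset, not the original code:

```python
from torch.utils.data import ChainDataset, DataLoader, IterableDataset

class RangeStream(IterableDataset):
    """Minimal stand-in for a streaming dataset."""
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        return iter(range(self.start, self.end))

# ChainDataset exhausts the first stream, then continues with the next one,
# all inside a single DataLoader loop instead of two separate loaders
chained = ChainDataset([RangeStream(3, 7), RangeStream(7, 10)])
loader = DataLoader(chained)

for batch in loader:
    print(batch)  # yields 3, 4, ..., 9 across both streams
```

The upside over two loaders is a single loop per epoch; the chained datasets are still consumed strictly in order.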
I don’t know how the initial request would work with IterableDatasets, as they are apparently streaming on the one hand but also support indexing on the other. The latter would be needed for random_split, I would assume, but the author didn’t follow up, so I don’t know how it was solved.
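One way the original poster’s “IterableDataset that also implements indexing” might look is a class that defines `__iter__` for streaming and additionally `__len__`/`__getitem__`, which is all random_split needs. This is a sketch under the assumption that the data fits in memory (otherwise indexing an actual stream is not possible); IndexedIterable is a hypothetical name:

```python
import torch
from torch.utils.data import IterableDataset, random_split

class IndexedIterable(IterableDataset):
    """Iterable dataset backed by an in-memory list, so it also
    supports len() and indexing. Only viable when the data fits in memory."""
    def __init__(self, items):
        self.items = list(items)

    def __iter__(self):           # streaming access
        return iter(self.items)

    def __len__(self):            # needed by random_split
        return len(self.items)

    def __getitem__(self, idx):   # indexed access, used by the resulting Subsets
        return self.items[idx]

ds = IndexedIterable(range(10))
train, test = random_split(
    ds, [8, 2], generator=torch.Generator().manual_seed(0))
print(len(train), len(test))  # 8 2
```

Note that the Subsets returned by random_split are map-style, so a DataLoader over them uses indexing rather than the streaming path.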