Is there a way in torchdata to shuffle data across shards in a multiprocessing setting? In my current pipeline, I shuffle the shards before applying the `sharding_filter` to distribute them among the worker processes, and then do an intra-shard shuffle on each individual worker. However, this way, batches generated by a worker process still contain samples from only one shard. Is there an elegant way to also shuffle the results returned by the individual workers?
My current pipeline looks similar to this:

```python
base_dp = (
    dp.iter.FileLister(input_path, masks="*.wds")
    .shuffle()  # shuffle shards
    .sharding_filter()
    .open_files(mode="b")
    .load_from_tar()
    .map(decode_data)
    .webdataset()
    .shuffle(buffer_size=<smaller than shard size>)  # on a single worker --> intra-shard shuffle
    .batch(batch_size)
)
```
My current approach is to define a dummy datapipe which forces the data back into the main process, and to perform the shuffle there:
```python
from typing import Iterator

import torchdata.datapipes as dp


@dp.functional_datapipe("collect_from_workers")
class WorkerResultCollector(dp.iter.IterDataPipe):
    def __init__(self, source: dp.iter.IterDataPipe):
        self.source = source

    def __iter__(self) -> Iterator:
        yield from self.source

    def is_replicable(self) -> bool:
        """Returning False forces the data back to the main process."""
        return False
```
```python
new_dp = (
    base_dp
    .collect_from_workers()
    .shuffle(buffer_size=<n * batch size used in the individual workers>, unbatch_level=1)
    .batch(batch_size)
)
```
Any feedback is appreciated: other ways to handle sharded data, an "official" collector datapipe to replace my hackish version, etc.
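For comparison, a plain-Python sketch (hypothetical helper, not a torchdata API) of another way to get cross-shard mixing: read several open shards per worker in an interleaved fashion, so the per-worker shuffle buffer already contains samples from multiple shards:

```python
import random

def interleave(shard_streams, seed=0):
    """Randomly interleave several open shard iterators; a downstream
    buffer shuffle then sees a mixture of shards instead of one at a time."""
    rng = random.Random(seed)
    streams = list(shard_streams)
    while streams:
        stream = rng.choice(streams)
        try:
            yield next(stream)
        except StopIteration:
            streams.remove(stream)

shards = [
    iter([("shard0", i) for i in range(5)]),
    iter([("shard1", i) for i in range(5)]),
]
mixed = list(interleave(shards))
# All ten samples come out, with the two shards interleaved.
```

If I am not mistaken, torchdata's `Multiplexer` (functional form `.mux`) does a similar round-robin interleaving of datapipes and might serve this purpose inside the pipeline.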