Is there a way in torchdata to shuffle data across shards in a multiprocessing setting? In my current pipeline, I shuffle the shards before applying the `sharding_filter` that distributes them among the worker processes, and then do an intra-shard shuffle on each individual worker. However, this way the batches produced by a worker process still contain samples from only one shard. Is there an elegant way to also shuffle the results returned by the individual workers?
My current pipeline looks similar to this:
```python
base_dp = (
    dp.iter.FileLister(input_path, masks="*.wds")
    .shuffle()  # shuffle shards
    .sharding_filter()
    .open_files(mode="b")
    .load_from_tar()
    .map(decode_data)
    .webdataset()
    .shuffle(buffer_size=<smaller than shard size>)  # on a single worker --> intra-shard shuffle
    .batch(batch_size)
)
```
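For context on why the intra-shard shuffle alone cannot mix shards: a bounded shuffle buffer can only move an item roughly `buffer_size` positions, so with a buffer smaller than a shard, early batches are drawn entirely from the first shard a worker reads. A minimal, torchdata-free sketch of the generic buffer logic (not torchdata's actual implementation):

```python
import random

def buffered_shuffle(stream, buffer_size, rng):
    """Shuffle an iterable with a bounded buffer, analogous to .shuffle(buffer_size=...)."""
    buf = []
    for item in stream:
        buf.append(item)
        if len(buf) >= buffer_size:
            # Only yield once the buffer is full, picking a random element.
            yield buf.pop(rng.randrange(len(buf)))
    while buf:
        yield buf.pop(rng.randrange(len(buf)))

# One worker reads two shards back to back, 100 samples each.
stream = [("shard_a", i) for i in range(100)] + [("shard_b", i) for i in range(100)]

# Buffer smaller than the shard: the 32nd output is drawn from at most the
# first 41 stream items, so the whole first batch comes from shard_a.
out = list(buffered_shuffle(iter(stream), buffer_size=10, rng=random.Random(0)))
assert all(label == "shard_a" for label, _ in out[:32])
```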
My current approach is to define a dummy datapipe that forces the data back into the main process, and then to perform the shuffle there:
```python
@dp.functional_datapipe("collect_from_workers")
class WorkerResultCollector(dp.iter.IterDataPipe):
    def __init__(self, source: dp.iter.IterDataPipe):
        self.source = source

    def __iter__(self) -> Iterator:
        yield from self.source

    def is_replicable(self) -> bool:
        """Method to force data back to the main process."""
        return False


new_dp = (
    base_dp
    .collect_from_workers()
    .shuffle(buffer_size=<n * size of batch size used in individual workers>, unbatch_level=1)
    .batch(batch_size)
)
```
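For comparison, here is a torchdata-free sketch of what collecting in the main process buys: once the per-worker streams are interleaved and shuffled with a buffer spanning several worker batches, early batches mix shards. The worker lists and round-robin interleave below are illustrative stand-ins for the multiprocessing machinery, not torchdata API:

```python
import random
from itertools import zip_longest

# Two workers, each pinned to one shard (the situation after sharding_filter).
worker0 = [("shard0", i) for i in range(100)]
worker1 = [("shard1", i) for i in range(100)]

# Round-robin interleave, roughly how results arrive in the main process.
interleaved = [x for pair in zip_longest(worker0, worker1) for x in pair]

# Main-process shuffle with a buffer spanning several worker batches.
rng = random.Random(0)
buf, shuffled = [], []
for item in interleaved:
    buf.append(item)
    if len(buf) >= 64:
        shuffled.append(buf.pop(rng.randrange(len(buf))))
while buf:
    shuffled.append(buf.pop(rng.randrange(len(buf))))

# The first batch now contains samples from both shards.
first_batch = shuffled[:32]
assert {label for label, _ in first_batch} == {"shard0", "shard1"}
```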
Any feedback is appreciated: other ways to handle sharded data, an "official" collector datapipe to replace my hackish version, …