Inter-shard shuffling with torchdata datapipes

Is there a way in torchdata to shuffle data across shards in a multiprocessing setting? In my current pipeline, I shuffle the shards before applying the sharding_filter that distributes them among the worker processes, and then do an intra-shard shuffle on the individual workers. This way, however, the batches generated by a worker process still contain samples from only one shard. Is there an elegant way to also shuffle the results returned by the individual workers?

My current pipeline looks similar to this:

import torchdata.datapipes as dp

base_dp = (
    dp.iter.FileLister(input_path, masks="*.wds")
    .shuffle()  # shuffle shards
    .sharding_filter()  # distribute shards among the worker processes
    .open_files(mode="b")
    .load_from_tar()
    .map(decode_data)
    .webdataset()
    .shuffle(buffer_size=<smaller than shard size>)  # runs on a single worker --> intra-shard shuffle
    .batch(batch_size)
)
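
For context, I consume the pipeline roughly like this (a minimal sketch, assuming torchdata's DataLoader2 with a MultiProcessingReadingService; num_workers=4 is just an example value):

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

# Each worker receives a subset of the shards via sharding_filter and
# yields batches that each come from a single shard.
rs = MultiProcessingReadingService(num_workers=4)
dl = DataLoader2(base_dp, reading_service=rs)

for batch in dl:
    ...  # every batch here still contains samples from only one shard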

My current approach is to define a dummy datapipe that forces the data back to the main process and to perform the shuffle there:

from typing import Iterator

@dp.functional_datapipe("collect_from_workers")
class WorkerResultCollector(dp.iter.IterDataPipe):
    def __init__(self, source: dp.iter.IterDataPipe):
        super().__init__()
        self.source = source

    def __iter__(self) -> Iterator:
        yield from self.source

    def is_replicable(self) -> bool:
        """Returning False forces this datapipe back onto the main process."""
        return False

new_dp = (
    base_dp
    .collect_from_workers()
    .shuffle(buffer_size=<n * batch size used in the individual workers>, unbatch_level=1)
    .batch(batch_size)
)
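
If I understand the reading-service behaviour correctly, the non-replicable collector should keep the final shuffle and re-batch on the main process, so usage would look roughly like this (again only a sketch; the num_workers value is a placeholder):

# Workers produce per-shard batches; the main process collects them,
# unbatches, shuffles across shards, and re-batches.
rs = MultiProcessingReadingService(num_workers=4)
dl = DataLoader2(new_dp, reading_service=rs)

for batch in dl:
    ...  # batches should now mix samples from different shards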

Any feedback is appreciated: other ways to handle sharded data, an “official” collector datapipe to replace my hackish version, …