When using datapipes, it seems like you want to apply the sharding_filter
as early as possible, in order to prevent data loader workers from doing duplicate work.
However, I’m struggling to understand how to use the sharding_filter
in the following scenario:
import torchdata.datapipes.iter as dp

def create_datapipe(s3_urls):
    pipe = dp.IterableWrapper(s3_urls)
    # download_url is a fake function that downloads a chunk from s3
    pipe = pipe.map(download_url)
    pipe = pipe.flatmap(lambda chunk: [item for item in chunk])
    pipe = pipe.batch(batch_size=256)
    pipe = pipe.map(do_expensive_computation)
    return pipe
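To check my understanding of the semantics: everything upstream of the sharding_filter runs in every worker, and the filter then keeps only every num_workers-th element for a given worker. Here is a plain-Python sketch of that mental model (no torchdata; all names and the fake 3-item chunks are mine) showing why an early filter avoids duplicate downloads while a late one does not:

```python
def simulate_worker(urls, worker_id, num_workers, filter_after_download):
    """Simulate one DataLoader worker's view of the pipeline.

    Everything upstream of the sharding_filter runs in every worker;
    the filter keeps only every num_workers-th element for this worker.
    """
    downloads = 0

    def download(url):
        nonlocal downloads
        downloads += 1  # count download work done by this worker
        return [f"{url}-item{i}" for i in range(3)]  # fake 3-sample chunk

    if filter_after_download:
        # sharding_filter placed AFTER the download: every worker downloads
        # every URL, then drops the chunks that aren't in its shard.
        chunks = [download(u) for u in urls]
        chunks = [c for i, c in enumerate(chunks) if i % num_workers == worker_id]
    else:
        # sharding_filter placed BEFORE the download: URLs are sharded first,
        # so each worker only downloads its own slice.
        shard = [u for i, u in enumerate(urls) if i % num_workers == worker_id]
        chunks = [download(u) for u in shard]

    items = [item for chunk in chunks for item in chunk]  # the flatmap stage
    return items, downloads

urls = [f"s3://bucket/{i}" for i in range(8)]
late = [simulate_worker(urls, w, 4, True) for w in range(4)]
early = [simulate_worker(urls, w, 4, False) for w in range(4)]
total_late = sum(d for _, d in late)    # 32 downloads: 4x duplicate work
total_early = sum(d for _, d in early)  # 8 downloads: each URL fetched once
```

Both placements yield the same items overall; the difference is purely how much upstream work gets duplicated across workers.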
Ideally, I would place the sharding_filter before do_expensive_computation, or before batch. The issue is that this parallelizes downloading from S3, which I don’t necessarily want.
Ideally, I’d have the downloading from S3 happen on a single thread that all workers read from, and only start parallelizing after the flatmap stage.
The reason this matters: imagine my chunks are big (e.g., each chunk contains 5K samples). If the number of chunks is not divisible by the number of workers, one worker has to process an extra chunk, and that worker will take a very long time.
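To put numbers on the imbalance (hypothetical figures: 9 chunks of 5,000 samples, 4 workers, assuming round-robin sharding): sharding at the chunk level leaves one worker with 50% more samples than the others, while sharding at the sample level, after the flatmap, comes out balanced:

```python
num_chunks, chunk_size, num_workers = 9, 5000, 4

# Round-robin sharding at the chunk level (i.e., before the flatmap):
# worker w gets chunks w, w + num_workers, w + 2*num_workers, ...
samples_per_worker = [
    sum(chunk_size for c in range(num_chunks) if c % num_workers == w)
    for w in range(num_workers)
]
# -> [15000, 10000, 10000, 10000]: worker 0 processes 50% more samples,
# and the whole epoch waits on it.

# Round-robin sharding at the sample level (i.e., after the flatmap):
total = num_chunks * chunk_size
per_worker_after_flatmap = [
    len(range(w, total, num_workers)) for w in range(num_workers)
]
# -> [11250, 11250, 11250, 11250]: balanced in this example.
```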